From 41b9965c14c684cc9aa1bfb2d3decba0b58a1373 Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Wed, 11 Nov 2020 20:59:44 +0800 Subject: [PATCH 01/32] rebase code from upstream --- taichi/codegen/codegen_llvm.cpp | 6 +++--- taichi/ir/type.h | 18 +++++++++++++++--- taichi/transforms/type_check.cpp | 7 ++----- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index eef98ff495ba7..70185050a6c95 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -328,7 +328,7 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) { auto from_size = 0; if (from->is()) { // TODO: replace 32 with a customizable type - from_size = 32; + from_size = data_type_size(from->cast()->get_compute_type()); } else { from_size = data_type_size(from); } @@ -1132,8 +1132,8 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { auto bit_end = builder->CreateAdd(bit_offset, tlctx->get_constant(cit->get_num_bits())); auto left = builder->CreateSub(tlctx->get_constant(32), bit_end); - auto right = builder->CreateAdd(tlctx->get_constant(32), - tlctx->get_constant(-cit->get_num_bits())); + auto right = builder->CreateSub(tlctx->get_constant(32), + tlctx->get_constant(cit->get_num_bits())); auto step1 = builder->CreateShl(bit_level_container, left); llvm::Value *step2 = nullptr; if (cit->get_is_signed()) diff --git a/taichi/ir/type.h b/taichi/ir/type.h index ab93d85c0190c..8e75a9b0229b6 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -167,11 +167,22 @@ class VectorType : public Type { class CustomIntType : public Type { public: CustomIntType(int num_bits, bool is_signed) - : num_bits_(num_bits), is_signed_(is_signed) { + : compute_type(nullptr), num_bits_(num_bits), is_signed_(is_signed) { + TI_ASSERT(num_bits <= 32); + compute_type = is_signed ? new PrimitiveType(PrimitiveTypeID::i32) : + new PrimitiveType(PrimitiveTypeID::u32); + } + + ~CustomIntType() override { + delete compute_type; } std::string to_string() const override; + Type* get_compute_type() { + return compute_type; + } + int get_num_bits() const { return num_bits_; } @@ -183,8 +194,9 @@ class CustomIntType : public Type { private: // TODO(type): for now we can uniformly use i32 as the "compute_type". It may // be a good idea to make "compute_type" also customizable. - int num_bits_; - bool is_signed_; + Type* compute_type{nullptr}; + int num_bits_{32}; + bool is_signed_{true}; }; class BitStructType : public Type { diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index df6fe8cdbb355..2282867868a2f 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -268,11 +268,8 @@ class TypeCheck : public IRVisitor { if (stmt->lhs->ret_type != stmt->rhs->ret_type) { auto promote_custom_int_type = [&](Stmt *stmt, Stmt *hs) { - if (hs->ret_type->is()) { - if (hs->ret_type->cast()->get_is_signed()) - return insert_type_cast_before(stmt, hs, get_data_type()); - else - return insert_type_cast_before(stmt, hs, get_data_type()); + if (auto cit = hs->ret_type->cast()) { + return insert_type_cast_before(stmt, hs, cit->get_compute_type()); } return hs; }; From 0b8fd8bc654b52b8c7f3bbd536a1d925e6fb7f34 Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Wed, 11 Nov 2020 22:03:06 +0800 Subject: [PATCH 02/32] add llvm_ptr_type and data type bits --- taichi/codegen/codegen_llvm.cpp | 34 +++++++++++++++++++++++++++++---- taichi/codegen/codegen_llvm.h | 2 ++ taichi/lang_util.cpp | 4 ++++ taichi/lang_util.h | 1 + 4 files changed, 37 insertions(+), 4 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 70185050a6c95..134856693b60c 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -618,6 +618,31 @@ llvm::Type *CodeGenLLVM::llvm_type(DataType dt) { return nullptr; } +llvm::Type *CodeGenLLVM::llvm_ptr_type(DataType dt) { + if (dt->is_primitive(PrimitiveTypeID::i8) || + dt->is_primitive(PrimitiveTypeID::u8)) { + return llvm::Type::getInt8PtrTy(*llvm_context); + } else if (dt->is_primitive(PrimitiveTypeID::i16) || + dt->is_primitive(PrimitiveTypeID::u16)) { + return llvm::Type::getInt16PtrTy(*llvm_context); + } else if (dt->is_primitive(PrimitiveTypeID::i32) || + dt->is_primitive(PrimitiveTypeID::u32)) { + return llvm::Type::getInt32PtrTy(*llvm_context); + } else if (dt->is_primitive(PrimitiveTypeID::i64) || + dt->is_primitive(PrimitiveTypeID::u64)) { + return llvm::Type::getInt64PtrTy(*llvm_context); + } else if (dt->is_primitive(PrimitiveTypeID::u1)) { + return llvm::Type::getInt1PtrTy(*llvm_context); + } else if (dt->is_primitive(PrimitiveTypeID::f32)) { + return llvm::Type::getFloatPtrTy(*llvm_context); + } else if (dt->is_primitive(PrimitiveTypeID::f64)) { + return llvm::Type::getDoublePtrTy(*llvm_context); + } else { + TI_NOT_IMPLEMENTED; + } + return nullptr; +} + void CodeGenLLVM::visit(TernaryOpStmt *stmt) { TI_ASSERT(stmt->op_type == TernaryOpType::select); llvm_val[stmt] = builder->CreateSelect( @@ -1108,7 +1133,7 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { builder->CreateCall( get_runtime_function("set_partial_bits_b32"), {builder->CreateBitCast(byte_ptr, - llvm::Type::getInt32PtrTy(*llvm_context)), + llvm_ptr_type(cit->get_compute_type())), bit_offset, tlctx->get_constant(cit->get_num_bits()), llvm_val[stmt->data]}); } else { @@ -1125,14 +1150,15 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { llvm::Value *byte_ptr, *bit_offset; read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset); auto bit_level_container = builder->CreateLoad(builder->CreateBitCast( - byte_ptr, llvm::Type::getInt32PtrTy(*llvm_context))); + byte_ptr, llvm_ptr_type(cit->get_compute_type()))); // 2. bit shifting // first left shift `32 - (offset + num_bits)` // then right shift `32 - num_bits` + auto compute_type_size = data_type_size(cit->get_compute_type()) * 8; auto bit_end = builder->CreateAdd(bit_offset, tlctx->get_constant(cit->get_num_bits())); - auto left = builder->CreateSub(tlctx->get_constant(32), bit_end); - auto right = builder->CreateSub(tlctx->get_constant(32), + auto left = builder->CreateSub(tlctx->get_constant(compute_type_size), bit_end); + auto right = builder->CreateSub(tlctx->get_constant(compute_type_size), tlctx->get_constant(cit->get_num_bits())); auto step1 = builder->CreateShl(bit_level_container, left); llvm::Value *step2 = nullptr; diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index e93c6faec44f0..bf32b5f3b9253 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -148,6 +148,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Type *llvm_type(DataType dt); + llvm::Type *llvm_ptr_type(DataType dt); + void visit(Block *stmt_list) override; void visit(AllocaStmt *stmt) override; diff --git a/taichi/lang_util.cpp b/taichi/lang_util.cpp index 06bfa22057564..f43698bbe0c45 100644 --- a/taichi/lang_util.cpp +++ b/taichi/lang_util.cpp @@ -125,6 +125,10 @@ int data_type_size(DataType t) { } } +int data_type_bits(DataType t) { + return data_type_size(t) * 8; +} + std::string data_type_short_name(DataType t) { if (!t->is()) { return t->to_string(); diff --git a/taichi/lang_util.h b/taichi/lang_util.h index fdaf3eac35b81..f6634f38472fc 100644 --- a/taichi/lang_util.h +++ b/taichi/lang_util.h @@ -345,6 +345,7 @@ std::string make_list(const std::vector &data, } int data_type_size(DataType t); +int data_type_bits(DataType t); DataType promoted_type(DataType a, DataType b); extern std::string compiled_lib_dir; From b1b5e564fdfcc176e324aa344243c1fc13c2af3f Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Fri, 13 Nov 2020 12:03:49 +0800 Subject: [PATCH 03/32] add some comments --- taichi/codegen/codegen_llvm.cpp | 6 +++--- taichi/ir/type.h | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 134856693b60c..c12da8984c293 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1152,9 +1152,9 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { auto bit_level_container = builder->CreateLoad(builder->CreateBitCast( byte_ptr, llvm_ptr_type(cit->get_compute_type()))); // 2. bit shifting - // first left shift `32 - (offset + num_bits)` - // then right shift `32 - num_bits` - auto compute_type_size = data_type_size(cit->get_compute_type()) * 8; + // first left shift `compute_type_size(like 32, 64, ...) - (offset + num_bits)` + // then right shift `compute_type_size - num_bits` + auto compute_type_size = data_type_bits(cit->get_compute_type()); auto bit_end = builder->CreateAdd(bit_offset, tlctx->get_constant(cit->get_num_bits())); auto left = builder->CreateSub(tlctx->get_constant(compute_type_size), bit_end); diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 8e75a9b0229b6..b732e537ed3e9 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -168,6 +168,8 @@ class CustomIntType : public Type { public: CustomIntType(int num_bits, bool is_signed) : compute_type(nullptr), num_bits_(num_bits), is_signed_(is_signed) { + // TODO(type): support customizable compute_type + // and should we expose it to users? TI_ASSERT(num_bits <= 32); compute_type = is_signed ? new PrimitiveType(PrimitiveTypeID::i32) : new PrimitiveType(PrimitiveTypeID::u32); From 100ea350a4ffbeb79563f0ed4157abe47f24c119 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Thu, 12 Nov 2020 23:12:38 -0500 Subject: [PATCH 04/32] [skip ci] enforce code format --- taichi/codegen/codegen_llvm.cpp | 21 +++++++++++---------- taichi/ir/type.h | 8 ++++---- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index c12da8984c293..2bf6c43438068 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -328,7 +328,8 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) { auto from_size = 0; if (from->is()) { // TODO: replace 32 with a customizable type - from_size = data_type_size(from->cast()->get_compute_type()); + from_size = + data_type_size(from->cast()->get_compute_type()); } else { from_size = data_type_size(from); } @@ -1130,12 +1131,11 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { auto cit = ptr_type->get_pointee_type()->as(); llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr; read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset); - builder->CreateCall( - get_runtime_function("set_partial_bits_b32"), - {builder->CreateBitCast(byte_ptr, - llvm_ptr_type(cit->get_compute_type())), - bit_offset, tlctx->get_constant(cit->get_num_bits()), - llvm_val[stmt->data]}); + builder->CreateCall(get_runtime_function("set_partial_bits_b32"), + {builder->CreateBitCast( + byte_ptr, llvm_ptr_type(cit->get_compute_type())), + bit_offset, tlctx->get_constant(cit->get_num_bits()), + llvm_val[stmt->data]}); } else { builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]); } @@ -1152,12 +1152,13 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { auto bit_level_container = builder->CreateLoad(builder->CreateBitCast( byte_ptr, llvm_ptr_type(cit->get_compute_type()))); // 2. bit shifting - // first left shift `compute_type_size(like 32, 64, ...) - (offset + num_bits)` - // then right shift `compute_type_size - num_bits` + // first left shift `compute_type_size(like 32, 64, ...) - (offset + + // num_bits)` then right shift `compute_type_size - num_bits` auto compute_type_size = data_type_bits(cit->get_compute_type()); auto bit_end = builder->CreateAdd(bit_offset, tlctx->get_constant(cit->get_num_bits())); - auto left = builder->CreateSub(tlctx->get_constant(compute_type_size), bit_end); + auto left = + builder->CreateSub(tlctx->get_constant(compute_type_size), bit_end); auto right = builder->CreateSub(tlctx->get_constant(compute_type_size), tlctx->get_constant(cit->get_num_bits())); auto step1 = builder->CreateShl(bit_level_container, left); diff --git a/taichi/ir/type.h b/taichi/ir/type.h index b732e537ed3e9..d41fdfba38224 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -171,8 +171,8 @@ class CustomIntType : public Type { // TODO(type): support customizable compute_type // and should we expose it to users? TI_ASSERT(num_bits <= 32); - compute_type = is_signed ? new PrimitiveType(PrimitiveTypeID::i32) : - new PrimitiveType(PrimitiveTypeID::u32); + compute_type = is_signed ? new PrimitiveType(PrimitiveTypeID::i32) + : new PrimitiveType(PrimitiveTypeID::u32); } ~CustomIntType() override { @@ -181,7 +181,7 @@ class CustomIntType : public Type { std::string to_string() const override; - Type* get_compute_type() { + Type *get_compute_type() { return compute_type; } @@ -196,7 +196,7 @@ class CustomIntType : public Type { private: // TODO(type): for now we can uniformly use i32 as the "compute_type". It may // be a good idea to make "compute_type" also customizable. - Type* compute_type{nullptr}; + Type *compute_type{nullptr}; int num_bits_{32}; bool is_signed_{true}; }; From 33999819da0fa8f44c73e2caeca18a31e2f78a14 Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Fri, 13 Nov 2020 16:35:30 +0800 Subject: [PATCH 05/32] use type factory --- taichi/ir/type.cpp | 14 ++++++++++++++ taichi/ir/type.h | 9 +-------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index cd57929cb54c1..ea17be9b0504f 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -101,6 +101,20 @@ std::string CustomIntType::to_string() const { return fmt::format("c{}{}", is_signed_ ? 'i' : 'u', num_bits_); } +CustomIntType::CustomIntType(int num_bits, bool is_signed): + compute_type(nullptr), num_bits_(num_bits), is_signed_(is_signed) { + // TODO(type): support customizable compute_type + // and should we expose it to users? + TI_ASSERT(num_bits <= 32); + if (is_signed) { + compute_type = + TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i32); + } else { + compute_type = + TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::u32); + } +} + std::string BitStructType::to_string() const { std::string str = "bs("; int num_members = (int)member_bit_offsets_.size(); diff --git a/taichi/ir/type.h b/taichi/ir/type.h index d41fdfba38224..053a316f5817f 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -166,14 +166,7 @@ class VectorType : public Type { class CustomIntType : public Type { public: - CustomIntType(int num_bits, bool is_signed) - : compute_type(nullptr), num_bits_(num_bits), is_signed_(is_signed) { - // TODO(type): support customizable compute_type - // and should we expose it to users? - TI_ASSERT(num_bits <= 32); - compute_type = is_signed ? new PrimitiveType(PrimitiveTypeID::i32) - : new PrimitiveType(PrimitiveTypeID::u32); - } + CustomIntType(int num_bits, bool is_signed); ~CustomIntType() override { delete compute_type; From e98f4504707a1332f2f1d48f34c402f30e8ab637 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Fri, 13 Nov 2020 03:36:13 -0500 Subject: [PATCH 06/32] [skip ci] enforce code format --- taichi/ir/type.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index ea17be9b0504f..21ad16aeea656 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -101,8 +101,8 @@ std::string CustomIntType::to_string() const { return fmt::format("c{}{}", is_signed_ ? 'i' : 'u', num_bits_); } -CustomIntType::CustomIntType(int num_bits, bool is_signed): - compute_type(nullptr), num_bits_(num_bits), is_signed_(is_signed) { +CustomIntType::CustomIntType(int num_bits, bool is_signed) + : compute_type(nullptr), num_bits_(num_bits), is_signed_(is_signed) { // TODO(type): support customizable compute_type // and should we expose it to users? TI_ASSERT(num_bits <= 32); From 9ff847647cac6e64a7a5fbee705562adbfd81d26 Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Sat, 14 Nov 2020 15:01:35 +0800 Subject: [PATCH 07/32] update runtime --- taichi/codegen/codegen_llvm.cpp | 13 +++++++++---- taichi/runtime/llvm/runtime.cpp | 30 +++++++++++++++++++----------- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 2bf6c43438068..dee27b9d739ca 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -933,8 +933,8 @@ void CodeGenLLVM::visit(KernelReturnStmt *stmt) { TI_NOT_IMPLEMENTED } else { auto intermediate_bits = 0; - if (stmt->value->ret_type->is()) { - intermediate_bits = 32; + if (auto cit = stmt->value->ret_type->cast()) { + intermediate_bits = data_type_bits(cit->get_compute_type()); } else { intermediate_bits = tlctx->get_data_type(stmt->value->ret_type)->getPrimitiveSizeInBits(); @@ -1131,11 +1131,14 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { auto cit = ptr_type->get_pointee_type()->as(); llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr; read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset); - builder->CreateCall(get_runtime_function("set_partial_bits_b32"), + auto runtime_func_name = fmt::format("set_partial_bits_b{}", + data_type_bits(cit->get_compute_type())); + builder->CreateCall(get_runtime_function(runtime_func_name), {builder->CreateBitCast( byte_ptr, llvm_ptr_type(cit->get_compute_type())), bit_offset, tlctx->get_constant(cit->get_num_bits()), - llvm_val[stmt->data]}); + builder->CreateIntCast(llvm_val[stmt->data], + llvm_type(cit->get_compute_type()), cit->get_is_signed())}); } else { builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]); } @@ -1161,6 +1164,8 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { builder->CreateSub(tlctx->get_constant(compute_type_size), bit_end); auto right = builder->CreateSub(tlctx->get_constant(compute_type_size), tlctx->get_constant(cit->get_num_bits())); + left = builder->CreateIntCast(left, llvm_type(cit->get_compute_type()), cit->get_is_signed()); + right = builder->CreateIntCast(right, llvm_type(cit->get_compute_type()), cit->get_is_signed()); auto step1 = builder->CreateShl(bit_level_container, left); llvm::Value *step2 = nullptr; if (cit->get_is_signed()) diff --git a/taichi/runtime/llvm/runtime.cpp b/taichi/runtime/llvm/runtime.cpp index 2f1d4e2f31f8d..e13335f9e4a10 100644 --- a/taichi/runtime/llvm/runtime.cpp +++ b/taichi/runtime/llvm/runtime.cpp @@ -81,9 +81,11 @@ using float32 = float; using float64 = double; using i8 = int8; +using i16 = int16; using i32 = int32; using i64 = int64; using u8 = uint8; +using u16 = uint16; using u32 = uint32; using u64 = uint64; using f32 = float32; @@ -1551,17 +1553,23 @@ void stack_push(Ptr stack, size_t max_num_elements, std::size_t element_size) { #include "internal_functions.h" -void set_partial_bits_b32(u32 *ptr, u32 offset, u32 bits, u32 value) { - u32 mask = ((((u32)1 << bits) - 1) << offset); - u32 new_value = 0; - u32 old_value = *ptr; - do { - old_value = *ptr; - new_value = (old_value & (~mask)) | (value << offset); - } while (!__atomic_compare_exchange(ptr, &old_value, &new_value, true, - std::memory_order::memory_order_seq_cst, - std::memory_order::memory_order_seq_cst)); -} +#define DEFINE_SET_PARTIAL_BITS(N) \ +void set_partial_bits_b##N(u##N* ptr, u32 offset, u32 bits, u##N value) { \ + u##N mask = ((((u##N)1 << bits) - 1) << offset); \ + u##N new_value = 0; \ + u##N old_value = *ptr; \ + do { \ + old_value = *ptr; \ + new_value = (old_value & (~mask)) | (u##N)(value << offset); \ + } while (!__atomic_compare_exchange(ptr, &old_value, &new_value, true, \ + std::memory_order::memory_order_seq_cst, \ + std::memory_order::memory_order_seq_cst)); \ +} + +DEFINE_SET_PARTIAL_BITS(8); +DEFINE_SET_PARTIAL_BITS(16); +DEFINE_SET_PARTIAL_BITS(32); +DEFINE_SET_PARTIAL_BITS(64); } #endif From 04cb2f35811b0efaa75ee165b25933b65580abb1 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Sat, 14 Nov 2020 02:02:49 -0500 Subject: [PATCH 08/32] [skip ci] enforce code format --- taichi/codegen/codegen_llvm.cpp | 24 ++++++++++++++---------- taichi/runtime/llvm/runtime.cpp | 23 ++++++++++++----------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index dee27b9d739ca..1582cb4b0758c 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1131,14 +1131,16 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { auto cit = ptr_type->get_pointee_type()->as(); llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr; read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset); - auto runtime_func_name = fmt::format("set_partial_bits_b{}", - data_type_bits(cit->get_compute_type())); - builder->CreateCall(get_runtime_function(runtime_func_name), - {builder->CreateBitCast( - byte_ptr, llvm_ptr_type(cit->get_compute_type())), - bit_offset, tlctx->get_constant(cit->get_num_bits()), - builder->CreateIntCast(llvm_val[stmt->data], - llvm_type(cit->get_compute_type()), cit->get_is_signed())}); + auto runtime_func_name = fmt::format( + "set_partial_bits_b{}", data_type_bits(cit->get_compute_type())); + builder->CreateCall( + get_runtime_function(runtime_func_name), + {builder->CreateBitCast(byte_ptr, + llvm_ptr_type(cit->get_compute_type())), + bit_offset, tlctx->get_constant(cit->get_num_bits()), + builder->CreateIntCast(llvm_val[stmt->data], + llvm_type(cit->get_compute_type()), + cit->get_is_signed())}); } else { builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]); } @@ -1164,8 +1166,10 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { builder->CreateSub(tlctx->get_constant(compute_type_size), bit_end); auto right = builder->CreateSub(tlctx->get_constant(compute_type_size), tlctx->get_constant(cit->get_num_bits())); - left = builder->CreateIntCast(left, llvm_type(cit->get_compute_type()), cit->get_is_signed()); - right = builder->CreateIntCast(right, llvm_type(cit->get_compute_type()), cit->get_is_signed()); + left = builder->CreateIntCast(left, llvm_type(cit->get_compute_type()), + cit->get_is_signed()); + right = builder->CreateIntCast(right, llvm_type(cit->get_compute_type()), + cit->get_is_signed()); auto step1 = builder->CreateShl(bit_level_container, left); llvm::Value *step2 = nullptr; if (cit->get_is_signed()) diff --git a/taichi/runtime/llvm/runtime.cpp b/taichi/runtime/llvm/runtime.cpp index e13335f9e4a10..6dac414f0becc 100644 --- a/taichi/runtime/llvm/runtime.cpp +++ b/taichi/runtime/llvm/runtime.cpp @@ -1554,17 +1554,18 @@ void stack_push(Ptr stack, size_t max_num_elements, std::size_t element_size) { #include "internal_functions.h" #define DEFINE_SET_PARTIAL_BITS(N) \ -void set_partial_bits_b##N(u##N* ptr, u32 offset, u32 bits, u##N value) { \ - u##N mask = ((((u##N)1 << bits) - 1) << offset); \ - u##N new_value = 0; \ - u##N old_value = *ptr; \ - do { \ - old_value = *ptr; \ - new_value = (old_value & (~mask)) | (u##N)(value << offset); \ - } while (!__atomic_compare_exchange(ptr, &old_value, &new_value, true, \ - std::memory_order::memory_order_seq_cst, \ - std::memory_order::memory_order_seq_cst)); \ -} + void set_partial_bits_b##N(u##N *ptr, u32 offset, u32 bits, u##N value) { \ + u##N mask = ((((u##N)1 << bits) - 1) << offset); \ + u##N new_value = 0; \ + u##N old_value = *ptr; \ + do { \ + old_value = *ptr; \ + new_value = (old_value & (~mask)) | (u##N)(value << offset); \ + } while ( \ + !__atomic_compare_exchange(ptr, &old_value, &new_value, true, \ + std::memory_order::memory_order_seq_cst, \ + std::memory_order::memory_order_seq_cst)); \ + } DEFINE_SET_PARTIAL_BITS(8); DEFINE_SET_PARTIAL_BITS(16); From 15ffdfa3807c0ec43d9ac2e8335eec17c296ddb5 Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Sun, 15 Nov 2020 18:16:41 +0800 Subject: [PATCH 09/32] modify bit_ptr struct and rebase --- taichi/codegen/codegen_llvm.cpp | 57 +++++++++++++++++--------------- taichi/codegen/codegen_llvm.h | 3 +- taichi/ir/type.cpp | 30 +++++++++++++++++ taichi/ir/type.h | 2 ++ taichi/ir/type_factory.cpp | 11 ++++++ taichi/ir/type_factory.h | 4 +++ taichi/llvm/llvm_codegen_utils.h | 7 +++- taichi/python/export_lang.cpp | 2 ++ taichi/runtime/llvm/runtime.cpp | 41 ++++++++++++++++------- 9 files changed, 116 insertions(+), 41 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 1582cb4b0758c..6a5fc63ed9417 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1129,18 +1129,14 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { auto ptr_type = stmt->ptr->ret_type->as(); if (ptr_type->is_bit_pointer()) { auto cit = ptr_type->get_pointee_type()->as(); - llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr; - read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset); - auto runtime_func_name = fmt::format( - "set_partial_bits_b{}", data_type_bits(cit->get_compute_type())); - builder->CreateCall( - get_runtime_function(runtime_func_name), - {builder->CreateBitCast(byte_ptr, - llvm_ptr_type(cit->get_compute_type())), - bit_offset, tlctx->get_constant(cit->get_num_bits()), - builder->CreateIntCast(llvm_val[stmt->data], - llvm_type(cit->get_compute_type()), - cit->get_is_signed())}); + llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr, *physical_type_size = nullptr; + read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset, + physical_type_size); + builder->CreateCall(get_runtime_function("set_partial_bits"), + {builder->CreateBitCast( + byte_ptr, llvm::Type::getInt32PtrTy(*llvm_context)), + bit_offset, tlctx->get_constant(cit->get_num_bits()), + llvm_val[stmt->data], physical_type_size}); } else { builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]); } @@ -1152,24 +1148,22 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { if (stmt->ptr->ret_type->as()->is_bit_pointer()) { auto cit = stmt->ret_type->as(); // 1. load bit pointer - llvm::Value *byte_ptr, *bit_offset; - read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset); + llvm::Value *byte_ptr, *bit_offset, *physical_type_size; + read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset, + physical_type_size); auto bit_level_container = builder->CreateLoad(builder->CreateBitCast( - byte_ptr, llvm_ptr_type(cit->get_compute_type()))); + byte_ptr, llvm::Type::getInt32PtrTy(*llvm_context))); // 2. bit shifting - // first left shift `compute_type_size(like 32, 64, ...) - (offset + + // first left shift `compute_type_size(like 32, 64, ...) - (offset +z // num_bits)` then right shift `compute_type_size - num_bits` - auto compute_type_size = data_type_bits(cit->get_compute_type()); auto bit_end = builder->CreateAdd(bit_offset, tlctx->get_constant(cit->get_num_bits())); auto left = - builder->CreateSub(tlctx->get_constant(compute_type_size), bit_end); - auto right = builder->CreateSub(tlctx->get_constant(compute_type_size), + builder->CreateSub(physical_type_size, bit_end); + auto right = builder->CreateSub(physical_type_size, tlctx->get_constant(cit->get_num_bits())); - left = builder->CreateIntCast(left, llvm_type(cit->get_compute_type()), - cit->get_is_signed()); - right = builder->CreateIntCast(right, llvm_type(cit->get_compute_type()), - cit->get_is_signed()); + left = builder->CreateIntCast(left, physical_type_size->getType(), false); + right = builder->CreateIntCast(right, physical_type_size->getType(), false); auto step1 = builder->CreateShl(bit_level_container, left); llvm::Value *step2 = nullptr; if (cit->get_is_signed()) @@ -1272,15 +1266,18 @@ void CodeGenLLVM::visit(LinearizeStmt *stmt) { void CodeGenLLVM::visit(IntegerOffsetStmt *stmt){TI_NOT_IMPLEMENTED} llvm::Value *CodeGenLLVM::create_bit_ptr_struct(llvm::Value *byte_ptr_base, - llvm::Value *bit_offset) { + llvm::Value *bit_offset, + int num_bits) { // 1. create a bit pointer struct // struct bit_pointer { // i8* byte_ptr; // i32 offset; + // i32 physical_type_size; // }; auto struct_type = llvm::StructType::get( *llvm_context, {llvm::Type::getInt8PtrTy(*llvm_context), - llvm::Type::getInt32Ty(*llvm_context)}); + llvm::Type::getInt32Ty(*llvm_context), + llvm::Type::getInt32Ty(*llvm_context)}); // 2. alloca the bit pointer struct auto bit_ptr_struct = create_entry_block_alloca(struct_type); // 3. store `input_ptr` into `bit_ptr_struct` @@ -1293,6 +1290,10 @@ llvm::Value *CodeGenLLVM::create_bit_ptr_struct(llvm::Value *byte_ptr_base, builder->CreateStore( bit_offset, builder->CreateGEP(bit_ptr_struct, {tlctx->get_constant(0), tlctx->get_constant(1)})); + // 5. store `physical_type` in `bit_ptr_struct` + builder->CreateStore( + tlctx->get_constant(num_bits), builder->CreateGEP(bit_ptr_struct, {tlctx->get_constant(0), + tlctx->get_constant(2)})); return bit_ptr_struct; } @@ -1320,7 +1321,8 @@ void CodeGenLLVM::visit(SNodeLookupStmt *stmt) { snode->dt.get_ptr()->as()->get_element_num_bits(); auto offset = tlctx->get_constant(element_num_bits); offset = builder->CreateMul(offset, llvm_val[stmt->input_index]); - llvm_val[stmt] = create_bit_ptr_struct(llvm_val[stmt->input_snode], offset); + llvm_val[stmt] = create_bit_ptr_struct(llvm_val[stmt->input_snode], offset, + data_type_bits(snode->dt.get_ptr()->as()->get_element_type())); } else { TI_INFO(snode_type_name(snode->type)); TI_NOT_IMPLEMENTED @@ -1335,7 +1337,8 @@ void CodeGenLLVM::visit(GetChStmt *stmt) { auto bit_offset = bit_struct->get_member_bit_offset( stmt->input_snode->child_id(stmt->output_snode)); auto offset = tlctx->get_constant(bit_offset); - llvm_val[stmt] = create_bit_ptr_struct(llvm_val[stmt->input_ptr], offset); + llvm_val[stmt] = create_bit_ptr_struct(llvm_val[stmt->input_ptr], offset, + data_type_bits(bit_struct->get_physical_type())); } else { auto ch = create_call(stmt->output_snode->get_ch_from_parent_func_name(), {builder->CreateBitCast( diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index bf32b5f3b9253..ad2b8d5a48655 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -209,7 +209,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(IntegerOffsetStmt *stmt) override; llvm::Value *create_bit_ptr_struct(llvm::Value *byte_ptr_base, - llvm::Value *bit_offset); + llvm::Value *bit_offset, + int num_bits=32); void visit(SNodeLookupStmt *stmt) override; diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 21ad16aeea656..c02a43f05b671 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -115,6 +115,36 @@ CustomIntType::CustomIntType(int num_bits, bool is_signed) } } + +CustomIntType::CustomIntType(int compute_type_bits, int num_bits, bool is_signed) + : compute_type(nullptr), num_bits_(num_bits), is_signed_(is_signed) { + if (compute_type_bits == 32) { + if (is_signed) { + compute_type = + TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i32); + } else { + compute_type = + TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::u32); + } + } else if (compute_type_bits == 16) { + if (is_signed) { + compute_type = + TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i16); + } else { + compute_type = + TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::u16); + } + } else if (compute_type_bits == 8) { + if (is_signed) { + compute_type = + TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i8); + } else { + compute_type = + TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::u8); + } + } +} + std::string BitStructType::to_string() const { std::string str = "bs("; int num_members = (int)member_bit_offsets_.size(); diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 053a316f5817f..42fe2368a14fd 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -168,6 +168,8 @@ class CustomIntType : public Type { public: CustomIntType(int num_bits, bool is_signed); + CustomIntType(int compute_type_bits, int numBits, bool isSigned); + ~CustomIntType() override { delete compute_type; } diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index 2d1b7c12cd122..72f8fb8821ccc 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -46,6 +46,17 @@ Type *TypeFactory::get_custom_int_type(int num_bits, bool is_signed) { return custom_int_types_[key].get(); } +Type *TypeFactory::_get_custom_int_type(int compute_type_bits, + int num_bits, + bool is_signed) { + auto key = std::make_pair(num_bits, is_signed); + if (custom_int_types_with_compute_types_.find(key) == custom_int_types_with_compute_types_.end()) { + custom_int_types_with_compute_types_[key] = + std::make_unique(compute_type_bits, num_bits, is_signed); + } + return custom_int_types_with_compute_types_[key].get(); +} + Type *TypeFactory::get_bit_struct_type(PrimitiveType *physical_type, std::vector member_types, std::vector member_bit_offsets) { diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index aecd54b870f13..ea122e2071d88 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -23,6 +23,8 @@ class TypeFactory { Type *get_custom_int_type(int num_bits, bool is_signed); + Type *_get_custom_int_type(int compute_type_bits, int num_bits, bool is_signed); + Type *get_bit_struct_type(PrimitiveType *physical_type, std::vector member_types, std::vector member_bit_offsets); @@ -49,6 +51,8 @@ class TypeFactory { // TODO: use unordered map std::map, std::unique_ptr> custom_int_types_; + std::map, std::unique_ptr> custom_int_types_with_compute_types_; + // TODO: avoid duplication std::vector> bit_struct_types_; diff --git a/taichi/llvm/llvm_codegen_utils.h b/taichi/llvm/llvm_codegen_utils.h index bd9c1e717d334..f1c92f0c819cd 100644 --- a/taichi/llvm/llvm_codegen_utils.h +++ b/taichi/llvm/llvm_codegen_utils.h @@ -128,7 +128,8 @@ class LLVMModuleBuilder { void read_bit_pointer(llvm::Value *ptr, llvm::Value *&byte_ptr, - llvm::Value *&bit_offset) { + llvm::Value *&bit_offset, + llvm::Value *&physical_type) { // 1. load byte pointer auto byte_ptr_in_bit_struct = builder->CreateGEP( ptr, {tlctx->get_constant(0), tlctx->get_constant(0)}); @@ -140,6 +141,10 @@ class LLVMModuleBuilder { ptr, {tlctx->get_constant(0), tlctx->get_constant(1)}); bit_offset = builder->CreateLoad(bit_offset_in_bit_struct); TI_ASSERT(bit_offset->getType()->isIntegerTy(32)); + + auto physical_type_bit_struct = builder->CreateGEP(ptr, {tlctx->get_constant(0), tlctx->get_constant(2)}); + physical_type = builder->CreateLoad(physical_type_bit_struct); + TI_ASSERT(bit_offset->getType()->isIntegerTy(32)); } }; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 5ceaba48cbda9..8721940ab1ae4 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -728,6 +728,8 @@ void export_lang(py::module &m) { // TypeFactory on Python-scope pointer destruction. py::class_(m, "TypeFactory") .def("get_custom_int_type", &TypeFactory::get_custom_int_type, + py::return_value_policy::reference) + .def("_get_custom_int_type", &TypeFactory::_get_custom_int_type, py::return_value_policy::reference); m.def("get_type_factory_instance", TypeFactory::get_instance, diff --git a/taichi/runtime/llvm/runtime.cpp b/taichi/runtime/llvm/runtime.cpp index 6dac414f0becc..b0e068ba26375 100644 --- a/taichi/runtime/llvm/runtime.cpp +++ b/taichi/runtime/llvm/runtime.cpp @@ -1554,23 +1554,40 @@ void stack_push(Ptr stack, size_t max_num_elements, std::size_t element_size) { #include "internal_functions.h" #define DEFINE_SET_PARTIAL_BITS(N) \ - void set_partial_bits_b##N(u##N *ptr, u32 offset, u32 bits, u##N value) { \ - u##N mask = ((((u##N)1 << bits) - 1) << offset); \ - u##N new_value = 0; \ - u##N old_value = *ptr; \ - do { \ - old_value = *ptr; \ - new_value = (old_value & (~mask)) | (u##N)(value << offset); \ - } while ( \ - !__atomic_compare_exchange(ptr, &old_value, &new_value, true, \ - std::memory_order::memory_order_seq_cst, \ - std::memory_order::memory_order_seq_cst)); \ - } +void set_partial_bits_b##N(u##N* ptr, u32 offset, u32 bits, u##N value) { \ + u##N mask = ((((u##N)1 << bits) - 1) << offset); \ + u##N new_value = 0; \ + u##N old_value = *ptr; \ + do { \ + old_value = *ptr; \ + new_value = (old_value & (~mask)) | (value << offset); \ + } while (!__atomic_compare_exchange(ptr, &old_value, &new_value, true, \ + std::memory_order::memory_order_seq_cst, \ + std::memory_order::memory_order_seq_cst)); \ +} DEFINE_SET_PARTIAL_BITS(8); DEFINE_SET_PARTIAL_BITS(16); DEFINE_SET_PARTIAL_BITS(32); DEFINE_SET_PARTIAL_BITS(64); + +#define CALL_SET_PARTIAL_BITS_FUNC(ptr, offset, bits, value, n, N) \ + else if (n == N) { \ + set_partial_bits_b##N((u##N*)ptr, offset, bits, (u##N)value); \ + } + + +void set_partial_bits(u32* ptr, u32 offset, u32 bits, u32 value, u32 n) { + if (false) { + } + CALL_SET_PARTIAL_BITS_FUNC(ptr, offset, bits, value, n, 8) + CALL_SET_PARTIAL_BITS_FUNC(ptr, offset, bits, value, n, 16) + CALL_SET_PARTIAL_BITS_FUNC(ptr, offset, bits, value, n, 32) + CALL_SET_PARTIAL_BITS_FUNC(ptr, offset, bits, value, n, 64) + else { + assert(false); + } +} } #endif From a44bd21ccd0219decf102b234d5e1db961f26d00 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Sun, 15 Nov 2020 05:33:56 -0500 Subject: [PATCH 10/32] [skip ci] enforce code format --- taichi/codegen/codegen_llvm.cpp | 35 ++++++++++++++++++-------------- taichi/codegen/codegen_llvm.h | 2 +- taichi/ir/type.cpp | 5 +++-- taichi/ir/type_factory.cpp | 3 ++- taichi/ir/type_factory.h | 7 +++++-- taichi/llvm/llvm_codegen_utils.h | 3 ++- taichi/runtime/llvm/runtime.cpp | 32 ++++++++++++++--------------- 7 files changed, 49 insertions(+), 38 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 6a5fc63ed9417..9d167a3281e17 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1129,14 +1129,16 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { auto ptr_type = stmt->ptr->ret_type->as(); if (ptr_type->is_bit_pointer()) { auto cit = ptr_type->get_pointee_type()->as(); - llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr, *physical_type_size = nullptr; + llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr, + *physical_type_size = nullptr; read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset, physical_type_size); - builder->CreateCall(get_runtime_function("set_partial_bits"), - {builder->CreateBitCast( - byte_ptr, llvm::Type::getInt32PtrTy(*llvm_context)), - bit_offset, tlctx->get_constant(cit->get_num_bits()), - llvm_val[stmt->data], physical_type_size}); + builder->CreateCall( + get_runtime_function("set_partial_bits"), + {builder->CreateBitCast(byte_ptr, + llvm::Type::getInt32PtrTy(*llvm_context)), + bit_offset, tlctx->get_constant(cit->get_num_bits()), + llvm_val[stmt->data], physical_type_size}); } else { builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]); } @@ -1158,8 +1160,7 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { // num_bits)` then right shift `compute_type_size - num_bits` auto bit_end = builder->CreateAdd(bit_offset, tlctx->get_constant(cit->get_num_bits())); - auto left = - builder->CreateSub(physical_type_size, bit_end); + auto left = builder->CreateSub(physical_type_size, bit_end); auto right = builder->CreateSub(physical_type_size, tlctx->get_constant(cit->get_num_bits())); left = builder->CreateIntCast(left, physical_type_size->getType(), false); @@ -1277,7 +1278,7 @@ llvm::Value *CodeGenLLVM::create_bit_ptr_struct(llvm::Value *byte_ptr_base, auto struct_type = llvm::StructType::get( *llvm_context, {llvm::Type::getInt8PtrTy(*llvm_context), llvm::Type::getInt32Ty(*llvm_context), - llvm::Type::getInt32Ty(*llvm_context)}); + llvm::Type::getInt32Ty(*llvm_context)}); // 2. alloca the bit pointer struct auto bit_ptr_struct = create_entry_block_alloca(struct_type); // 3. store `input_ptr` into `bit_ptr_struct` @@ -1292,8 +1293,9 @@ llvm::Value *CodeGenLLVM::create_bit_ptr_struct(llvm::Value *byte_ptr_base, tlctx->get_constant(1)})); // 5. store `physical_type` in `bit_ptr_struct` builder->CreateStore( - tlctx->get_constant(num_bits), builder->CreateGEP(bit_ptr_struct, {tlctx->get_constant(0), - tlctx->get_constant(2)})); + tlctx->get_constant(num_bits), + builder->CreateGEP(bit_ptr_struct, + {tlctx->get_constant(0), tlctx->get_constant(2)})); return bit_ptr_struct; } @@ -1321,8 +1323,10 @@ void CodeGenLLVM::visit(SNodeLookupStmt *stmt) { snode->dt.get_ptr()->as()->get_element_num_bits(); auto offset = tlctx->get_constant(element_num_bits); offset = builder->CreateMul(offset, llvm_val[stmt->input_index]); - llvm_val[stmt] = create_bit_ptr_struct(llvm_val[stmt->input_snode], offset, - data_type_bits(snode->dt.get_ptr()->as()->get_element_type())); + llvm_val[stmt] = create_bit_ptr_struct( + llvm_val[stmt->input_snode], offset, + data_type_bits( + snode->dt.get_ptr()->as()->get_element_type())); } else { TI_INFO(snode_type_name(snode->type)); TI_NOT_IMPLEMENTED @@ -1337,8 +1341,9 @@ void CodeGenLLVM::visit(GetChStmt *stmt) { auto bit_offset = bit_struct->get_member_bit_offset( stmt->input_snode->child_id(stmt->output_snode)); auto offset = tlctx->get_constant(bit_offset); - llvm_val[stmt] = create_bit_ptr_struct(llvm_val[stmt->input_ptr], offset, - data_type_bits(bit_struct->get_physical_type())); + llvm_val[stmt] = + create_bit_ptr_struct(llvm_val[stmt->input_ptr], offset, + data_type_bits(bit_struct->get_physical_type())); } else { auto ch = create_call(stmt->output_snode->get_ch_from_parent_func_name(), {builder->CreateBitCast( diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index ad2b8d5a48655..d89498f71c7cb 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -210,7 +210,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *create_bit_ptr_struct(llvm::Value *byte_ptr_base, llvm::Value *bit_offset, - int num_bits=32); + int num_bits = 32); void visit(SNodeLookupStmt *stmt) override; diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index c02a43f05b671..a850ae2f1deb0 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -115,8 +115,9 @@ CustomIntType::CustomIntType(int num_bits, bool is_signed) } } - -CustomIntType::CustomIntType(int compute_type_bits, int num_bits, bool is_signed) +CustomIntType::CustomIntType(int compute_type_bits, + int num_bits, + bool is_signed) : compute_type(nullptr), num_bits_(num_bits), is_signed_(is_signed) { if (compute_type_bits == 32) { if (is_signed) { diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index 72f8fb8821ccc..f8d2eb85dfe6c 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -50,7 +50,8 @@ Type *TypeFactory::_get_custom_int_type(int compute_type_bits, int num_bits, bool is_signed) { auto key = std::make_pair(num_bits, is_signed); - if (custom_int_types_with_compute_types_.find(key) == custom_int_types_with_compute_types_.end()) { + if (custom_int_types_with_compute_types_.find(key) == + custom_int_types_with_compute_types_.end()) { custom_int_types_with_compute_types_[key] = std::make_unique(compute_type_bits, num_bits, is_signed); } diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index ea122e2071d88..e39838c7c030f 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -23,7 +23,9 @@ class TypeFactory { Type *get_custom_int_type(int num_bits, bool is_signed); - Type *_get_custom_int_type(int compute_type_bits, int num_bits, bool is_signed); + Type *_get_custom_int_type(int compute_type_bits, + int num_bits, + bool is_signed); Type *get_bit_struct_type(PrimitiveType *physical_type, std::vector member_types, @@ -51,7 +53,8 @@ class TypeFactory { // TODO: use unordered map std::map, std::unique_ptr> custom_int_types_; - std::map, std::unique_ptr> custom_int_types_with_compute_types_; + std::map, std::unique_ptr> + custom_int_types_with_compute_types_; // TODO: avoid duplication std::vector> bit_struct_types_; diff --git a/taichi/llvm/llvm_codegen_utils.h b/taichi/llvm/llvm_codegen_utils.h index f1c92f0c819cd..3247c29d88edf 100644 --- a/taichi/llvm/llvm_codegen_utils.h +++ b/taichi/llvm/llvm_codegen_utils.h @@ -142,7 +142,8 @@ class LLVMModuleBuilder { bit_offset = builder->CreateLoad(bit_offset_in_bit_struct); TI_ASSERT(bit_offset->getType()->isIntegerTy(32)); - auto physical_type_bit_struct = builder->CreateGEP(ptr, {tlctx->get_constant(0), tlctx->get_constant(2)}); + auto physical_type_bit_struct = builder->CreateGEP( + ptr, {tlctx->get_constant(0), tlctx->get_constant(2)}); physical_type = builder->CreateLoad(physical_type_bit_struct); TI_ASSERT(bit_offset->getType()->isIntegerTy(32)); } diff --git a/taichi/runtime/llvm/runtime.cpp b/taichi/runtime/llvm/runtime.cpp index b0e068ba26375..cbdafba1377e8 100644 --- a/taichi/runtime/llvm/runtime.cpp +++ b/taichi/runtime/llvm/runtime.cpp @@ -1554,30 +1554,30 @@ void stack_push(Ptr stack, size_t max_num_elements, std::size_t element_size) { #include "internal_functions.h" #define DEFINE_SET_PARTIAL_BITS(N) \ -void set_partial_bits_b##N(u##N* ptr, u32 offset, u32 bits, u##N value) { \ - u##N mask = ((((u##N)1 << bits) - 1) << offset); \ - u##N new_value = 0; \ - u##N old_value = *ptr; \ - do { \ - old_value = *ptr; \ - new_value = (old_value & (~mask)) | (value << offset); \ - } while (!__atomic_compare_exchange(ptr, &old_value, &new_value, true, \ - std::memory_order::memory_order_seq_cst, \ - std::memory_order::memory_order_seq_cst)); \ -} + void set_partial_bits_b##N(u##N *ptr, u32 offset, u32 bits, u##N value) { \ + u##N mask = ((((u##N)1 << bits) - 1) << offset); \ + u##N new_value = 0; \ + u##N old_value = *ptr; \ + do { \ + old_value = *ptr; \ + new_value = (old_value & (~mask)) | (value << offset); \ + } while ( \ + !__atomic_compare_exchange(ptr, &old_value, &new_value, true, \ + std::memory_order::memory_order_seq_cst, \ + std::memory_order::memory_order_seq_cst)); \ + } DEFINE_SET_PARTIAL_BITS(8); DEFINE_SET_PARTIAL_BITS(16); DEFINE_SET_PARTIAL_BITS(32); DEFINE_SET_PARTIAL_BITS(64); -#define CALL_SET_PARTIAL_BITS_FUNC(ptr, offset, bits, value, n, N) \ - else if (n == N) { \ - set_partial_bits_b##N((u##N*)ptr, offset, bits, (u##N)value); \ +#define CALL_SET_PARTIAL_BITS_FUNC(ptr, offset, bits, value, n, N) \ + else if (n == N) { \ + set_partial_bits_b##N((u##N *)ptr, offset, bits, (u##N)value); \ } - -void set_partial_bits(u32* ptr, u32 offset, u32 bits, u32 value, u32 n) { +void set_partial_bits(u32 *ptr, u32 offset, u32 bits, u32 value, u32 n) { if (false) { } CALL_SET_PARTIAL_BITS_FUNC(ptr, offset, bits, value, n, 8) From 6d82d3900cb9c8db531c216a2ff016a35aaea0ca Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Mon, 16 Nov 2020 14:46:52 +0800 Subject: [PATCH 11/32] fix bit_array --- taichi/codegen/codegen_llvm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 9d167a3281e17..6b48b49dcc1d6 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1326,7 +1326,7 @@ void CodeGenLLVM::visit(SNodeLookupStmt *stmt) { llvm_val[stmt] = create_bit_ptr_struct( llvm_val[stmt->input_snode], offset, data_type_bits( - snode->dt.get_ptr()->as()->get_element_type())); + snode->dt.get_ptr()->as()->get_physical_type())); } else { TI_INFO(snode_type_name(snode->type)); TI_NOT_IMPLEMENTED From fab966e7184bd0b40fa70e75d9685912e9fc650f Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Mon, 16 Nov 2020 20:59:08 +0800 Subject: [PATCH 12/32] use runtime to do global loading --- taichi/codegen/codegen_llvm.cpp | 25 +++-------- taichi/ir/type.cpp | 57 ++++++++++++++---------- taichi/ir/type.h | 17 +++++--- taichi/runtime/llvm/runtime.cpp | 77 +++++++++++++++++++++++++-------- 4 files changed, 109 insertions(+), 67 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 6b48b49dcc1d6..d37e70ece6019 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1153,25 +1153,12 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { llvm::Value *byte_ptr, *bit_offset, *physical_type_size; read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset, physical_type_size); - auto bit_level_container = builder->CreateLoad(builder->CreateBitCast( - byte_ptr, llvm::Type::getInt32PtrTy(*llvm_context))); - // 2. bit shifting - // first left shift `compute_type_size(like 32, 64, ...) - (offset +z - // num_bits)` then right shift `compute_type_size - num_bits` - auto bit_end = builder->CreateAdd(bit_offset, - tlctx->get_constant(cit->get_num_bits())); - auto left = builder->CreateSub(physical_type_size, bit_end); - auto right = builder->CreateSub(physical_type_size, - tlctx->get_constant(cit->get_num_bits())); - left = builder->CreateIntCast(left, physical_type_size->getType(), false); - right = builder->CreateIntCast(right, physical_type_size->getType(), false); - auto step1 = builder->CreateShl(bit_level_container, left); - llvm::Value *step2 = nullptr; - if (cit->get_is_signed()) - step2 = builder->CreateAShr(step1, right); - else - step2 = builder->CreateLShr(step1, right); - llvm_val[stmt] = step2; + auto tmp = builder->CreateCall( + get_runtime_function("load_partial_bits"), + {byte_ptr, + bit_offset, tlctx->get_constant(cit->get_num_bits()), + physical_type_size, tlctx->get_constant((uint32)cit->get_is_signed())}); + llvm_val[stmt] = builder->CreateIntCast(tmp, llvm_type(cit->get_compute_type()), cit->get_is_signed()); } else { llvm_val[stmt] = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), llvm_val[stmt->ptr]); diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index a850ae2f1deb0..e5b526c93b2fa 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -115,34 +115,43 @@ CustomIntType::CustomIntType(int num_bits, bool is_signed) } } +#define SET_COMPUTE_TYPE(n, N) \ + else if (n == N) { \ + if (is_signed) \ + type_id = PrimitiveTypeID::i##N; \ + else \ + type_id = PrimitiveTypeID::u##N; \ + } + CustomIntType::CustomIntType(int compute_type_bits, int num_bits, bool is_signed) : compute_type(nullptr), num_bits_(num_bits), is_signed_(is_signed) { - if (compute_type_bits == 32) { - if (is_signed) { - compute_type = - TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i32); - } else { - compute_type = - TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::u32); - } - } else if (compute_type_bits == 16) { - if (is_signed) { - compute_type = - TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i16); - } else { - compute_type = - TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::u16); - } - } else if (compute_type_bits == 8) { - if (is_signed) { - compute_type = - TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i8); - } else { - compute_type = - TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::u8); - } + auto type_id = PrimitiveTypeID::unknown; + if (false) { + } + SET_COMPUTE_TYPE(compute_type_bits, 64) + SET_COMPUTE_TYPE(compute_type_bits, 32) + SET_COMPUTE_TYPE(compute_type_bits, 16) + SET_COMPUTE_TYPE(compute_type_bits, 8) + else { + TI_NOT_IMPLEMENTED + } + compute_type = TypeFactory::get_instance().get_primitive_type(type_id); +} + +BitStructType::BitStructType(PrimitiveType *physical_type, + std::vector member_types, + std::vector member_bit_offsets) + : physical_type_(physical_type), + member_types_(member_types), + member_bit_offsets_(member_bit_offsets) { + TI_ASSERT(member_types_.size() == member_bit_offsets_.size()); + int physical_type_bits = data_type_bits(physical_type); + for (auto i = 0; i < member_types_.size(); ++i) { + auto bits_end = member_types_[i]->as()->get_num_bits() + + member_bit_offsets_[i]; + TI_ASSERT(physical_type_bits >= bits_end) } } diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 42fe2368a14fd..bf13eaa82c513 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -200,12 +200,17 @@ class BitStructType : public Type { public: BitStructType(PrimitiveType *physical_type, std::vector member_types, - std::vector member_bit_offsets) - : physical_type_(physical_type), - member_types_(member_types), - member_bit_offsets_(member_bit_offsets) { - TI_ASSERT(member_types_.size() == member_bit_offsets_.size()); - } + std::vector member_bit_offsets); +// : physical_type_(physical_type), +// member_types_(member_types), +// member_bit_offsets_(member_bit_offsets) { +// TI_ASSERT(member_types_.size() == member_bit_offsets_.size()); +// int physical_type_bits = data_type_bits(physical_type); +// for (auto i = 0; i < member_types_.size(); ++i) { +// auto bits_end = member_types_[i]->as()->get_num_bits() + member_bit_offsets_[i]; +// TI_ASSERT(physical_type_bits >= bits_end) +// } +// } std::string to_string() const override; diff --git a/taichi/runtime/llvm/runtime.cpp b/taichi/runtime/llvm/runtime.cpp index cbdafba1377e8..f4e87816407fe 100644 --- a/taichi/runtime/llvm/runtime.cpp +++ b/taichi/runtime/llvm/runtime.cpp @@ -1554,40 +1554,81 @@ void stack_push(Ptr stack, size_t max_num_elements, std::size_t element_size) { #include "internal_functions.h" #define DEFINE_SET_PARTIAL_BITS(N) \ - void set_partial_bits_b##N(u##N *ptr, u32 offset, u32 bits, u##N value) { \ - u##N mask = ((((u##N)1 << bits) - 1) << offset); \ - u##N new_value = 0; \ - u##N old_value = *ptr; \ - do { \ - old_value = *ptr; \ - new_value = (old_value & (~mask)) | (value << offset); \ - } while ( \ - !__atomic_compare_exchange(ptr, &old_value, &new_value, true, \ - std::memory_order::memory_order_seq_cst, \ - std::memory_order::memory_order_seq_cst)); \ - } +void set_partial_bits_b##N(u##N *ptr, u32 offset, u32 bits, u##N value) { \ + u##N mask = ((((u##N)1 << bits) - 1) << offset); \ + u##N new_value = 0; \ + u##N old_value = *ptr; \ + do { \ + old_value = *ptr; \ + new_value = (old_value & (~mask)) | (value << offset); \ + } while ( \ + !__atomic_compare_exchange(ptr, &old_value, &new_value, true, \ + std::memory_order::memory_order_seq_cst, \ + std::memory_order::memory_order_seq_cst)); \ +} DEFINE_SET_PARTIAL_BITS(8); DEFINE_SET_PARTIAL_BITS(16); DEFINE_SET_PARTIAL_BITS(32); DEFINE_SET_PARTIAL_BITS(64); -#define CALL_SET_PARTIAL_BITS_FUNC(ptr, offset, bits, value, n, N) \ + + +void set_partial_bits(u32 *ptr, u32 offset, u32 bits, u32 value, u32 n) { +#define CALL_SET_PARTIAL_BITS_FUNC(N) \ else if (n == N) { \ set_partial_bits_b##N((u##N *)ptr, offset, bits, (u##N)value); \ } -void set_partial_bits(u32 *ptr, u32 offset, u32 bits, u32 value, u32 n) { if (false) { } - CALL_SET_PARTIAL_BITS_FUNC(ptr, offset, bits, value, n, 8) - CALL_SET_PARTIAL_BITS_FUNC(ptr, offset, bits, value, n, 16) - CALL_SET_PARTIAL_BITS_FUNC(ptr, offset, bits, value, n, 32) - CALL_SET_PARTIAL_BITS_FUNC(ptr, offset, bits, value, n, 64) + CALL_SET_PARTIAL_BITS_FUNC(8) + CALL_SET_PARTIAL_BITS_FUNC(16) + CALL_SET_PARTIAL_BITS_FUNC(32) + CALL_SET_PARTIAL_BITS_FUNC(64) else { assert(false); } } + +#define DEFINE_LOAD_PARTIAL_BITS(s, N) \ +int64 load_partial_bits_##s##N(i8 *ptr, u32 offset, u32 bits) { \ + s##N value = *(s##N*)ptr; \ + value = (value << (N - offset - bits)); \ + value = (value >> (N - bits)); \ + return value; \ +} + +DEFINE_LOAD_PARTIAL_BITS(i, 8) +DEFINE_LOAD_PARTIAL_BITS(i, 16) +DEFINE_LOAD_PARTIAL_BITS(i, 32) +DEFINE_LOAD_PARTIAL_BITS(i, 64) + +DEFINE_LOAD_PARTIAL_BITS(u, 8) +DEFINE_LOAD_PARTIAL_BITS(u, 16) +DEFINE_LOAD_PARTIAL_BITS(u, 32) +DEFINE_LOAD_PARTIAL_BITS(u, 64) + +int64 load_partial_bits(i8 *ptr, u32 offset, u32 bits, u32 n, u32 is_signed) { +#define CALL_LOAD_PARTIAL_BITS_FUNC(N) \ + else if (n == N) { \ + if (is_signed) \ + return load_partial_bits_i##N(ptr, offset, bits); \ + else \ + return load_partial_bits_u##N(ptr, offset, bits);\ + } + if (false) { + } + CALL_LOAD_PARTIAL_BITS_FUNC(8) + CALL_LOAD_PARTIAL_BITS_FUNC(16) + CALL_LOAD_PARTIAL_BITS_FUNC(32) + CALL_LOAD_PARTIAL_BITS_FUNC(64) + else { + assert(false); + } + return 0; +} + } #endif From 7bfbcb56ca0d719e7c5176cb92ff879902ed8914 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Mon, 16 Nov 2020 07:59:44 -0500 Subject: [PATCH 13/32] [skip ci] enforce code format --- taichi/codegen/codegen_llvm.cpp | 9 +++--- taichi/ir/type.cpp | 14 ++++----- taichi/ir/type.h | 20 ++++++------ taichi/runtime/llvm/runtime.cpp | 55 ++++++++++++++++----------------- 4 files changed, 47 insertions(+), 51 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index d37e70ece6019..b90fd4e084486 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1155,10 +1155,11 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { physical_type_size); auto tmp = builder->CreateCall( get_runtime_function("load_partial_bits"), - {byte_ptr, - bit_offset, tlctx->get_constant(cit->get_num_bits()), - physical_type_size, tlctx->get_constant((uint32)cit->get_is_signed())}); - llvm_val[stmt] = builder->CreateIntCast(tmp, llvm_type(cit->get_compute_type()), cit->get_is_signed()); + {byte_ptr, bit_offset, tlctx->get_constant(cit->get_num_bits()), + physical_type_size, + tlctx->get_constant((uint32)cit->get_is_signed())}); + llvm_val[stmt] = builder->CreateIntCast( + tmp, llvm_type(cit->get_compute_type()), cit->get_is_signed()); } else { llvm_val[stmt] = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), llvm_val[stmt->ptr]); diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index e5b526c93b2fa..061259cc985f2 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -115,11 +115,11 @@ CustomIntType::CustomIntType(int num_bits, bool is_signed) } } -#define SET_COMPUTE_TYPE(n, N) \ - else if (n == N) { \ - if (is_signed) \ +#define SET_COMPUTE_TYPE(n, N) \ + else if (n == N) { \ + if (is_signed) \ type_id = PrimitiveTypeID::i##N; \ - else \ + else \ type_id = PrimitiveTypeID::u##N; \ } @@ -134,10 +134,8 @@ CustomIntType::CustomIntType(int compute_type_bits, SET_COMPUTE_TYPE(compute_type_bits, 32) SET_COMPUTE_TYPE(compute_type_bits, 16) SET_COMPUTE_TYPE(compute_type_bits, 8) - else { - TI_NOT_IMPLEMENTED - } - compute_type = TypeFactory::get_instance().get_primitive_type(type_id); + else {TI_NOT_IMPLEMENTED} compute_type = + TypeFactory::get_instance().get_primitive_type(type_id); } BitStructType::BitStructType(PrimitiveType *physical_type, diff --git a/taichi/ir/type.h b/taichi/ir/type.h index bf13eaa82c513..6c02aa129463e 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -201,16 +201,16 @@ class BitStructType : public Type { BitStructType(PrimitiveType *physical_type, std::vector member_types, std::vector member_bit_offsets); -// : physical_type_(physical_type), -// member_types_(member_types), -// member_bit_offsets_(member_bit_offsets) { -// TI_ASSERT(member_types_.size() == member_bit_offsets_.size()); -// int physical_type_bits = data_type_bits(physical_type); -// for (auto i = 0; i < member_types_.size(); ++i) { -// auto bits_end = member_types_[i]->as()->get_num_bits() + member_bit_offsets_[i]; -// TI_ASSERT(physical_type_bits >= bits_end) -// } -// } + // : physical_type_(physical_type), + // member_types_(member_types), + // member_bit_offsets_(member_bit_offsets) { + // TI_ASSERT(member_types_.size() == member_bit_offsets_.size()); + // int physical_type_bits = data_type_bits(physical_type); + // for (auto i = 0; i < member_types_.size(); ++i) { + // auto bits_end = member_types_[i]->as()->get_num_bits() + // + member_bit_offsets_[i]; TI_ASSERT(physical_type_bits >= bits_end) + // } + // } std::string to_string() const override; diff --git a/taichi/runtime/llvm/runtime.cpp b/taichi/runtime/llvm/runtime.cpp index f4e87816407fe..d8ae7742ab872 100644 --- a/taichi/runtime/llvm/runtime.cpp +++ b/taichi/runtime/llvm/runtime.cpp @@ -1554,28 +1554,26 @@ void stack_push(Ptr stack, size_t max_num_elements, std::size_t element_size) { #include "internal_functions.h" #define DEFINE_SET_PARTIAL_BITS(N) \ -void set_partial_bits_b##N(u##N *ptr, u32 offset, u32 bits, u##N value) { \ - u##N mask = ((((u##N)1 << bits) - 1) << offset); \ - u##N new_value = 0; \ - u##N old_value = *ptr; \ - do { \ - old_value = *ptr; \ - new_value = (old_value & (~mask)) | (value << offset); \ - } while ( \ - !__atomic_compare_exchange(ptr, &old_value, &new_value, true, \ - std::memory_order::memory_order_seq_cst, \ - std::memory_order::memory_order_seq_cst)); \ -} + void set_partial_bits_b##N(u##N *ptr, u32 offset, u32 bits, u##N value) { \ + u##N mask = ((((u##N)1 << bits) - 1) << offset); \ + u##N new_value = 0; \ + u##N old_value = *ptr; \ + do { \ + old_value = *ptr; \ + new_value = (old_value & (~mask)) | (value << offset); \ + } while ( \ + !__atomic_compare_exchange(ptr, &old_value, &new_value, true, \ + std::memory_order::memory_order_seq_cst, \ + std::memory_order::memory_order_seq_cst)); \ + } DEFINE_SET_PARTIAL_BITS(8); DEFINE_SET_PARTIAL_BITS(16); DEFINE_SET_PARTIAL_BITS(32); DEFINE_SET_PARTIAL_BITS(64); - - void set_partial_bits(u32 *ptr, u32 offset, u32 bits, u32 value, u32 n) { -#define CALL_SET_PARTIAL_BITS_FUNC(N) \ +#define CALL_SET_PARTIAL_BITS_FUNC(N) \ else if (n == N) { \ set_partial_bits_b##N((u##N *)ptr, offset, bits, (u##N)value); \ } @@ -1591,13 +1589,13 @@ void set_partial_bits(u32 *ptr, u32 offset, u32 bits, u32 value, u32 n) { } } -#define DEFINE_LOAD_PARTIAL_BITS(s, N) \ -int64 load_partial_bits_##s##N(i8 *ptr, u32 offset, u32 bits) { \ - s##N value = *(s##N*)ptr; \ - value = (value << (N - offset - bits)); \ - value = (value >> (N - bits)); \ - return value; \ -} +#define DEFINE_LOAD_PARTIAL_BITS(s, N) \ + int64 load_partial_bits_##s##N(i8 *ptr, u32 offset, u32 bits) { \ + s##N value = *(s##N *)ptr; \ + value = (value << (N - offset - bits)); \ + value = (value >> (N - bits)); \ + return value; \ + } DEFINE_LOAD_PARTIAL_BITS(i, 8) DEFINE_LOAD_PARTIAL_BITS(i, 16) @@ -1610,12 +1608,12 @@ DEFINE_LOAD_PARTIAL_BITS(u, 32) DEFINE_LOAD_PARTIAL_BITS(u, 64) int64 load_partial_bits(i8 *ptr, u32 offset, u32 bits, u32 n, u32 is_signed) { -#define CALL_LOAD_PARTIAL_BITS_FUNC(N) \ - else if (n == N) { \ - if (is_signed) \ - return load_partial_bits_i##N(ptr, offset, bits); \ - else \ - return load_partial_bits_u##N(ptr, offset, bits);\ +#define CALL_LOAD_PARTIAL_BITS_FUNC(N) \ + else if (n == N) { \ + if (is_signed) \ + return load_partial_bits_i##N(ptr, offset, bits); \ + else \ + return load_partial_bits_u##N(ptr, offset, bits); \ } if (false) { } @@ -1628,7 +1626,6 @@ int64 load_partial_bits(i8 *ptr, u32 offset, u32 bits, u32 n, u32 is_signed) { } return 0; } - } #endif From 1f9e5118be5c47597026979e6181b9b1342be60c Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Mon, 16 Nov 2020 21:43:30 +0800 Subject: [PATCH 14/32] add tests --- tests/python/test_bit_struct.py | 71 +++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/tests/python/test_bit_struct.py b/tests/python/test_bit_struct.py index 618babdc2643b..7a77a5a13bb70 100644 --- a/tests/python/test_bit_struct.py +++ b/tests/python/test_bit_struct.py @@ -75,3 +75,74 @@ def verify_val(idx: ti.i32): for idx in range(len(test_case_np)): set_val.__wrapped__(idx) verify_val.__wrapped__(idx) + + +@ti.test(arch=ti.cpu, debug=True, cfg_optimization=False) +def test_bit_struct_with_physical_type(): + ci8_5 = ti.type_factory_._get_custom_int_type(8, 5, True) + ci8_3 = ti.type_factory_._get_custom_int_type(8, 3, False) + ci16_4 = ti.type_factory_._get_custom_int_type(16, 4, True) + cu16_12 = ti.type_factory_._get_custom_int_type(16, 12, False) + + ci32_17 = ti.type_factory_._get_custom_int_type(32, 17, True) + ci32_11 = ti.type_factory_._get_custom_int_type(32, 11, True) + cu32_4 = ti.type_factory_._get_custom_int_type(32, 4, False) + + ci64_33 = ti.type_factory_._get_custom_int_type(64, 32, True) + ci64_20 = ti.type_factory_._get_custom_int_type(64, 21, False) + ci64_7 = ti.type_factory_._get_custom_int_type(64, 7, False) + + a = ti.field(dtype=ci8_5) + b = ti.field(dtype=ci8_3) + + c = ti.field(dtype=ci16_4) + d = ti.field(dtype=cu16_12) + + e = ti.field(dtype=ci32_17) + f = ti.field(dtype=ci32_11) + g = ti.field(dtype=cu32_4) + + h = ti.field(dtype=ci64_33) + i = ti.field(dtype=ci64_20) + j = ti.field(dtype=ci64_7) + + ti.root._bit_struct(num_bits=8).place(a, b) + ti.root._bit_struct(num_bits=16).place(c, d) + ti.root._bit_struct(num_bits=32).place(e, f, g) + ti.root._bit_struct(num_bits=64).place(h, i, j) + + test_case_np = np.array( + [[2**4 - 1, 2**3 - 1, 2**3 - 1, 2**12 - 1, 2**16-1, 2**10-1, 2**4-1, 2**31-1, 2**21-1, 2**7-1], + [-2**3, 2**2 - 1, -2**2, 2**11 - 1, -2**15, -2**9, 2**2-1, -2**30, 2**20-1, 2**6-1], + [3, 4, 5, 16, 21, 34, 1, 2020, 456, 123]], + dtype=np.int32) + test_case = ti.Vector.field(10, dtype=ti.i32, shape=len(test_case_np)) + test_case.from_numpy(test_case_np) + + @ti.kernel + def set_val(idx: ti.i32): + a[None] = test_case[idx][0] + b[None] = test_case[idx][1] + c[None] = test_case[idx][2] + d[None] = test_case[idx][3] + e[None] = test_case[idx][4] + f[None] = test_case[idx][5] + g[None] = test_case[idx][6] + h[None] = test_case[idx][7] + i[None] = test_case[idx][8] + + @ti.kernel + def verify_val(idx: ti.i32): + assert a[None] == test_case[idx][0] + assert b[None] == test_case[idx][1] + assert c[None] == test_case[idx][2] + assert d[None] == test_case[idx][3] + assert e[None] == test_case[idx][4] + assert f[None] == test_case[idx][5] + assert g[None] == test_case[idx][6] + assert h[None] == test_case[idx][7] + assert i[None] == test_case[idx][8] + + for idx in range(len(test_case_np)): + set_val(idx) + verify_val(idx) \ No newline at end of file From 9b109f588eca8451f1e6935a435d29bfd8155d1c Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Mon, 16 Nov 2020 08:44:50 -0500 Subject: [PATCH 15/32] [skip ci] enforce code format --- tests/python/test_bit_struct.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/python/test_bit_struct.py b/tests/python/test_bit_struct.py index 7a77a5a13bb70..afa3c1ed1afce 100644 --- a/tests/python/test_bit_struct.py +++ b/tests/python/test_bit_struct.py @@ -111,11 +111,15 @@ def test_bit_struct_with_physical_type(): ti.root._bit_struct(num_bits=32).place(e, f, g) ti.root._bit_struct(num_bits=64).place(h, i, j) - test_case_np = np.array( - [[2**4 - 1, 2**3 - 1, 2**3 - 1, 2**12 - 1, 2**16-1, 2**10-1, 2**4-1, 2**31-1, 2**21-1, 2**7-1], - [-2**3, 2**2 - 1, -2**2, 2**11 - 1, -2**15, -2**9, 2**2-1, -2**30, 2**20-1, 2**6-1], - [3, 4, 5, 16, 21, 34, 1, 2020, 456, 123]], - dtype=np.int32) + test_case_np = np.array([[ + 2**4 - 1, 2**3 - 1, 2**3 - 1, 2**12 - 1, 2**16 - 1, 2**10 - 1, + 2**4 - 1, 2**31 - 1, 2**21 - 1, 2**7 - 1 + ], + [ + -2**3, 2**2 - 1, -2**2, 2**11 - 1, -2**15, + -2**9, 2**2 - 1, -2**30, 2**20 - 1, 2**6 - 1 + ], [3, 4, 5, 16, 21, 34, 1, 2020, 456, 123]], + dtype=np.int32) test_case = ti.Vector.field(10, dtype=ti.i32, shape=len(test_case_np)) test_case.from_numpy(test_case_np) @@ -145,4 +149,4 @@ def verify_val(idx: ti.i32): for idx in range(len(test_case_np)): set_val(idx) - verify_val(idx) \ No newline at end of file + verify_val(idx) From 10cd722b36c18c78f4b2d0caf9bd2e7889190367 Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Tue, 17 Nov 2020 10:24:52 +0800 Subject: [PATCH 16/32] add more tetst cases and rebase --- tests/python/test_bit_struct.py | 38 ++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/tests/python/test_bit_struct.py b/tests/python/test_bit_struct.py index afa3c1ed1afce..40b1f5678f2af 100644 --- a/tests/python/test_bit_struct.py +++ b/tests/python/test_bit_struct.py @@ -106,21 +106,27 @@ def test_bit_struct_with_physical_type(): i = ti.field(dtype=ci64_20) j = ti.field(dtype=ci64_7) + k = ti.field(dtype=ci16_4) + l = ti.field(dtype=cu16_12) + + m = ti.field(dtype=ci32_17) + n = ti.field(dtype=ci32_11) + o = ti.field(dtype=cu32_4) + ti.root._bit_struct(num_bits=8).place(a, b) ti.root._bit_struct(num_bits=16).place(c, d) ti.root._bit_struct(num_bits=32).place(e, f, g) ti.root._bit_struct(num_bits=64).place(h, i, j) - test_case_np = np.array([[ - 2**4 - 1, 2**3 - 1, 2**3 - 1, 2**12 - 1, 2**16 - 1, 2**10 - 1, - 2**4 - 1, 2**31 - 1, 2**21 - 1, 2**7 - 1 - ], - [ - -2**3, 2**2 - 1, -2**2, 2**11 - 1, -2**15, - -2**9, 2**2 - 1, -2**30, 2**20 - 1, 2**6 - 1 - ], [3, 4, 5, 16, 21, 34, 1, 2020, 456, 123]], - dtype=np.int32) - test_case = ti.Vector.field(10, dtype=ti.i32, shape=len(test_case_np)) + ti.root._bit_struct(num_bits=32).place(k, l) + ti.root._bit_struct(num_bits=64).place(m, n, o) + + test_case_np = np.array( + [[2**4 - 1, 2**3 - 1, 2**3 - 1, 2**12 - 1, 2**16-1, 2**10-1, 2**4-1, 2**31-1, 2**21-1, 2**7-1, 2**3 - 1, 2**12 - 1, 2**16-1, 2**10-1, 2**4-1], + [-2**3, 2**2 - 1, -2**2, 2**11 - 1, -2**15, -2**9, 2**2-1, -2**30, 2**20-1, 2**6-1, -2**2, 2**11 - 1, -2**15, -2**9, 2**2-1,], + [3, 4, 5, 16, 21, 34, 1, 2020, 456, 123, 5, 16, 21, 34, 1]], + dtype=np.int32) + test_case = ti.Vector.field(15, dtype=ti.i32, shape=len(test_case_np)) test_case.from_numpy(test_case_np) @ti.kernel @@ -134,6 +140,12 @@ def set_val(idx: ti.i32): g[None] = test_case[idx][6] h[None] = test_case[idx][7] i[None] = test_case[idx][8] + j[None] = test_case[idx][9] + k[None] = test_case[idx][10] + l[None] = test_case[idx][11] + m[None] = test_case[idx][12] + n[None] = test_case[idx][13] + o[None] = test_case[idx][14] @ti.kernel def verify_val(idx: ti.i32): @@ -146,6 +158,12 @@ def verify_val(idx: ti.i32): assert g[None] == test_case[idx][6] assert h[None] == test_case[idx][7] assert i[None] == test_case[idx][8] + assert j[None] == test_case[idx][9] + assert k[None] == test_case[idx][10] + assert l[None] == test_case[idx][11] + assert m[None] == test_case[idx][12] + assert n[None] == test_case[idx][13] + assert o[None] == test_case[idx][14] for idx in range(len(test_case_np)): set_val(idx) From 45dabf8d0e093f7da0067bfc919c86ab3d13cd7a Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Mon, 16 Nov 2020 21:27:51 -0500 Subject: [PATCH 17/32] [skip ci] enforce code format --- tests/python/test_bit_struct.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/tests/python/test_bit_struct.py b/tests/python/test_bit_struct.py index 40b1f5678f2af..5cc45671e5f09 100644 --- a/tests/python/test_bit_struct.py +++ b/tests/python/test_bit_struct.py @@ -122,9 +122,28 @@ def test_bit_struct_with_physical_type(): ti.root._bit_struct(num_bits=64).place(m, n, o) test_case_np = np.array( - [[2**4 - 1, 2**3 - 1, 2**3 - 1, 2**12 - 1, 2**16-1, 2**10-1, 2**4-1, 2**31-1, 2**21-1, 2**7-1, 2**3 - 1, 2**12 - 1, 2**16-1, 2**10-1, 2**4-1], - [-2**3, 2**2 - 1, -2**2, 2**11 - 1, -2**15, -2**9, 2**2-1, -2**30, 2**20-1, 2**6-1, -2**2, 2**11 - 1, -2**15, -2**9, 2**2-1,], - [3, 4, 5, 16, 21, 34, 1, 2020, 456, 123, 5, 16, 21, 34, 1]], + [[ + 2**4 - 1, 2**3 - 1, 2**3 - 1, 2**12 - 1, 2**16 - 1, 2**10 - 1, + 2**4 - 1, 2**31 - 1, 2**21 - 1, 2**7 - 1, 2**3 - 1, 2**12 - 1, + 2**16 - 1, 2**10 - 1, 2**4 - 1 + ], + [ + -2**3, + 2**2 - 1, + -2**2, + 2**11 - 1, + -2**15, + -2**9, + 2**2 - 1, + -2**30, + 2**20 - 1, + 2**6 - 1, + -2**2, + 2**11 - 1, + -2**15, + -2**9, + 2**2 - 1, + ], [3, 4, 5, 16, 21, 34, 1, 2020, 456, 123, 5, 16, 21, 34, 1]], dtype=np.int32) test_case = ti.Vector.field(15, dtype=ti.i32, shape=len(test_case_np)) test_case.from_numpy(test_case_np) From 9d993418044924ade8dba65d8ca9d089b35b9458 Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Wed, 18 Nov 2020 17:46:54 +0800 Subject: [PATCH 18/32] fix bit struct --- taichi/codegen/codegen_llvm.cpp | 84 +++++++++++++------------------- taichi/codegen/codegen_llvm.h | 3 +- taichi/ir/type.cpp | 23 ++++++++- taichi/ir/type.h | 22 +++++---- taichi/llvm/llvm_codegen_utils.h | 8 +-- taichi/runtime/llvm/runtime.cpp | 54 -------------------- taichi/struct/struct_llvm.cpp | 1 + 7 files changed, 69 insertions(+), 126 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index b90fd4e084486..aa9772de51507 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -620,28 +620,7 @@ llvm::Type *CodeGenLLVM::llvm_type(DataType dt) { } llvm::Type *CodeGenLLVM::llvm_ptr_type(DataType dt) { - if (dt->is_primitive(PrimitiveTypeID::i8) || - dt->is_primitive(PrimitiveTypeID::u8)) { - return llvm::Type::getInt8PtrTy(*llvm_context); - } else if (dt->is_primitive(PrimitiveTypeID::i16) || - dt->is_primitive(PrimitiveTypeID::u16)) { - return llvm::Type::getInt16PtrTy(*llvm_context); - } else if (dt->is_primitive(PrimitiveTypeID::i32) || - dt->is_primitive(PrimitiveTypeID::u32)) { - return llvm::Type::getInt32PtrTy(*llvm_context); - } else if (dt->is_primitive(PrimitiveTypeID::i64) || - dt->is_primitive(PrimitiveTypeID::u64)) { - return llvm::Type::getInt64PtrTy(*llvm_context); - } else if (dt->is_primitive(PrimitiveTypeID::u1)) { - return llvm::Type::getInt1PtrTy(*llvm_context); - } else if (dt->is_primitive(PrimitiveTypeID::f32)) { - return llvm::Type::getFloatPtrTy(*llvm_context); - } else if (dt->is_primitive(PrimitiveTypeID::f64)) { - return llvm::Type::getDoublePtrTy(*llvm_context); - } else { - TI_NOT_IMPLEMENTED; - } - return nullptr; + return llvm::PointerType::get(llvm_type(dt), 0); } void CodeGenLLVM::visit(TernaryOpStmt *stmt) { @@ -1129,16 +1108,15 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { auto ptr_type = stmt->ptr->ret_type->as(); if (ptr_type->is_bit_pointer()) { auto cit = ptr_type->get_pointee_type()->as(); - llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr, - *physical_type_size = nullptr; - read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset, - physical_type_size); + llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr; + read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset); + auto func_name = fmt::format("set_partial_bits_b{}", data_type_bits(cit->get_physical_type())); builder->CreateCall( - get_runtime_function("set_partial_bits"), + get_runtime_function(func_name), {builder->CreateBitCast(byte_ptr, - llvm::Type::getInt32PtrTy(*llvm_context)), + llvm_ptr_type(cit->get_physical_type())), bit_offset, tlctx->get_constant(cit->get_num_bits()), - llvm_val[stmt->data], physical_type_size}); + builder->CreateIntCast(llvm_val[stmt->data], llvm_type(cit->get_physical_type()), false)}); } else { builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]); } @@ -1150,16 +1128,30 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { if (stmt->ptr->ret_type->as()->is_bit_pointer()) { auto cit = stmt->ret_type->as(); // 1. load bit pointer - llvm::Value *byte_ptr, *bit_offset, *physical_type_size; - read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset, - physical_type_size); - auto tmp = builder->CreateCall( - get_runtime_function("load_partial_bits"), - {byte_ptr, bit_offset, tlctx->get_constant(cit->get_num_bits()), - physical_type_size, - tlctx->get_constant((uint32)cit->get_is_signed())}); + llvm::Value *byte_ptr, *bit_offset; + read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset); + + auto bit_level_container = builder->CreateLoad(builder->CreateBitCast( + byte_ptr, llvm_ptr_type(cit->get_physical_type()))); + // 2. bit shifting + // first left shift `physical_type - (offset + num_bits)` + // then right shift `physical_type - num_bits` + auto bit_end = builder->CreateAdd(bit_offset, + tlctx->get_constant(cit->get_num_bits())); + auto left = builder->CreateSub(tlctx->get_constant(data_type_bits(cit->get_physical_type())), bit_end); + auto right = builder->CreateSub(tlctx->get_constant(data_type_bits(cit->get_physical_type())), + tlctx->get_constant(cit->get_num_bits())); + left = builder->CreateIntCast(left, bit_level_container->getType(), false); + right = builder->CreateIntCast(right, bit_level_container->getType(), false); + auto step1 = builder->CreateShl(bit_level_container, left); + llvm::Value *step2 = nullptr; + if (cit->get_is_signed()) + step2 = builder->CreateAShr(step1, right); + else + step2 = builder->CreateLShr(step1, right); + llvm_val[stmt] = builder->CreateIntCast( - tmp, llvm_type(cit->get_compute_type()), cit->get_is_signed()); + step2, llvm_type(cit->get_compute_type()), cit->get_is_signed()); } else { llvm_val[stmt] = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), llvm_val[stmt->ptr]); @@ -1255,13 +1247,11 @@ void CodeGenLLVM::visit(LinearizeStmt *stmt) { void CodeGenLLVM::visit(IntegerOffsetStmt *stmt){TI_NOT_IMPLEMENTED} llvm::Value *CodeGenLLVM::create_bit_ptr_struct(llvm::Value *byte_ptr_base, - llvm::Value *bit_offset, - int num_bits) { + llvm::Value *bit_offset) { // 1. create a bit pointer struct // struct bit_pointer { // i8* byte_ptr; // i32 offset; - // i32 physical_type_size; // }; auto struct_type = llvm::StructType::get( *llvm_context, {llvm::Type::getInt8PtrTy(*llvm_context), @@ -1279,11 +1269,6 @@ llvm::Value *CodeGenLLVM::create_bit_ptr_struct(llvm::Value *byte_ptr_base, builder->CreateStore( bit_offset, builder->CreateGEP(bit_ptr_struct, {tlctx->get_constant(0), tlctx->get_constant(1)})); - // 5. store `physical_type` in `bit_ptr_struct` - builder->CreateStore( - tlctx->get_constant(num_bits), - builder->CreateGEP(bit_ptr_struct, - {tlctx->get_constant(0), tlctx->get_constant(2)})); return bit_ptr_struct; } @@ -1312,9 +1297,7 @@ void CodeGenLLVM::visit(SNodeLookupStmt *stmt) { auto offset = tlctx->get_constant(element_num_bits); offset = builder->CreateMul(offset, llvm_val[stmt->input_index]); llvm_val[stmt] = create_bit_ptr_struct( - llvm_val[stmt->input_snode], offset, - data_type_bits( - snode->dt.get_ptr()->as()->get_physical_type())); + llvm_val[stmt->input_snode], offset); } else { TI_INFO(snode_type_name(snode->type)); TI_NOT_IMPLEMENTED @@ -1330,8 +1313,7 @@ void CodeGenLLVM::visit(GetChStmt *stmt) { stmt->input_snode->child_id(stmt->output_snode)); auto offset = tlctx->get_constant(bit_offset); llvm_val[stmt] = - create_bit_ptr_struct(llvm_val[stmt->input_ptr], offset, - data_type_bits(bit_struct->get_physical_type())); + create_bit_ptr_struct(llvm_val[stmt->input_ptr], offset); } else { auto ch = create_call(stmt->output_snode->get_ch_from_parent_func_name(), {builder->CreateBitCast( diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index d89498f71c7cb..bf32b5f3b9253 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -209,8 +209,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(IntegerOffsetStmt *stmt) override; llvm::Value *create_bit_ptr_struct(llvm::Value *byte_ptr_base, - llvm::Value *bit_offset, - int num_bits = 32); + llvm::Value *bit_offset); void visit(SNodeLookupStmt *stmt) override; diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 061259cc985f2..02ee106e4e33f 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -102,7 +102,8 @@ std::string CustomIntType::to_string() const { } CustomIntType::CustomIntType(int num_bits, bool is_signed) - : compute_type(nullptr), num_bits_(num_bits), is_signed_(is_signed) { + : compute_type(nullptr), physical_type(nullptr), + num_bits_(num_bits), is_signed_(is_signed) { // TODO(type): support customizable compute_type // and should we expose it to users? TI_ASSERT(num_bits <= 32); @@ -126,7 +127,8 @@ CustomIntType::CustomIntType(int num_bits, bool is_signed) CustomIntType::CustomIntType(int compute_type_bits, int num_bits, bool is_signed) - : compute_type(nullptr), num_bits_(num_bits), is_signed_(is_signed) { + : compute_type(nullptr), physical_type(nullptr), + num_bits_(num_bits), is_signed_(is_signed) { auto type_id = PrimitiveTypeID::unknown; if (false) { } @@ -138,6 +140,23 @@ CustomIntType::CustomIntType(int compute_type_bits, TypeFactory::get_instance().get_primitive_type(type_id); } +CustomIntType::CustomIntType(int compute_type_bits, + Type *physical_type, + int num_bits, + bool is_signed) + : compute_type(nullptr), physical_type(physical_type), + num_bits_(num_bits), is_signed_(is_signed) { + auto type_id = PrimitiveTypeID::unknown; + if (false) { + } + SET_COMPUTE_TYPE(compute_type_bits, 64) + SET_COMPUTE_TYPE(compute_type_bits, 32) + SET_COMPUTE_TYPE(compute_type_bits, 16) + SET_COMPUTE_TYPE(compute_type_bits, 8) + else {TI_NOT_IMPLEMENTED} compute_type = + TypeFactory::get_instance().get_primitive_type(type_id); +} + BitStructType::BitStructType(PrimitiveType *physical_type, std::vector member_types, std::vector member_bit_offsets) diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 6c02aa129463e..03dd4b7308730 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -170,12 +170,23 @@ class CustomIntType : public Type { CustomIntType(int compute_type_bits, int numBits, bool isSigned); + CustomIntType(int compute_type_bits, Type *physical_type, int num_bits, + bool is_signed); + ~CustomIntType() override { delete compute_type; } std::string to_string() const override; + void set_physical_type(Type*physical_type_) { + this->physical_type = physical_type_; + } + + Type* get_physical_type() { + return physical_type; + } + Type *get_compute_type() { return compute_type; } @@ -192,6 +203,7 @@ class CustomIntType : public Type { // TODO(type): for now we can uniformly use i32 as the "compute_type". It may // be a good idea to make "compute_type" also customizable. Type *compute_type{nullptr}; + Type *physical_type{nullptr}; int num_bits_{32}; bool is_signed_{true}; }; @@ -201,16 +213,6 @@ class BitStructType : public Type { BitStructType(PrimitiveType *physical_type, std::vector member_types, std::vector member_bit_offsets); - // : physical_type_(physical_type), - // member_types_(member_types), - // member_bit_offsets_(member_bit_offsets) { - // TI_ASSERT(member_types_.size() == member_bit_offsets_.size()); - // int physical_type_bits = data_type_bits(physical_type); - // for (auto i = 0; i < member_types_.size(); ++i) { - // auto bits_end = member_types_[i]->as()->get_num_bits() - // + member_bit_offsets_[i]; TI_ASSERT(physical_type_bits >= bits_end) - // } - // } std::string to_string() const override; diff --git a/taichi/llvm/llvm_codegen_utils.h b/taichi/llvm/llvm_codegen_utils.h index 3247c29d88edf..bd9c1e717d334 100644 --- a/taichi/llvm/llvm_codegen_utils.h +++ b/taichi/llvm/llvm_codegen_utils.h @@ -128,8 +128,7 @@ class LLVMModuleBuilder { void read_bit_pointer(llvm::Value *ptr, llvm::Value *&byte_ptr, - llvm::Value *&bit_offset, - llvm::Value *&physical_type) { + llvm::Value *&bit_offset) { // 1. load byte pointer auto byte_ptr_in_bit_struct = builder->CreateGEP( ptr, {tlctx->get_constant(0), tlctx->get_constant(0)}); @@ -141,11 +140,6 @@ class LLVMModuleBuilder { ptr, {tlctx->get_constant(0), tlctx->get_constant(1)}); bit_offset = builder->CreateLoad(bit_offset_in_bit_struct); TI_ASSERT(bit_offset->getType()->isIntegerTy(32)); - - auto physical_type_bit_struct = builder->CreateGEP( - ptr, {tlctx->get_constant(0), tlctx->get_constant(2)}); - physical_type = builder->CreateLoad(physical_type_bit_struct); - TI_ASSERT(bit_offset->getType()->isIntegerTy(32)); } }; diff --git a/taichi/runtime/llvm/runtime.cpp b/taichi/runtime/llvm/runtime.cpp index d8ae7742ab872..4ab593c1303a8 100644 --- a/taichi/runtime/llvm/runtime.cpp +++ b/taichi/runtime/llvm/runtime.cpp @@ -1572,60 +1572,6 @@ DEFINE_SET_PARTIAL_BITS(16); DEFINE_SET_PARTIAL_BITS(32); DEFINE_SET_PARTIAL_BITS(64); -void set_partial_bits(u32 *ptr, u32 offset, u32 bits, u32 value, u32 n) { -#define CALL_SET_PARTIAL_BITS_FUNC(N) \ - else if (n == N) { \ - set_partial_bits_b##N((u##N *)ptr, offset, bits, (u##N)value); \ - } - - if (false) { - } - CALL_SET_PARTIAL_BITS_FUNC(8) - CALL_SET_PARTIAL_BITS_FUNC(16) - CALL_SET_PARTIAL_BITS_FUNC(32) - CALL_SET_PARTIAL_BITS_FUNC(64) - else { - assert(false); - } -} - -#define DEFINE_LOAD_PARTIAL_BITS(s, N) \ - int64 load_partial_bits_##s##N(i8 *ptr, u32 offset, u32 bits) { \ - s##N value = *(s##N *)ptr; \ - value = (value << (N - offset - bits)); \ - value = (value >> (N - bits)); \ - return value; \ - } - -DEFINE_LOAD_PARTIAL_BITS(i, 8) -DEFINE_LOAD_PARTIAL_BITS(i, 16) -DEFINE_LOAD_PARTIAL_BITS(i, 32) -DEFINE_LOAD_PARTIAL_BITS(i, 64) - -DEFINE_LOAD_PARTIAL_BITS(u, 8) -DEFINE_LOAD_PARTIAL_BITS(u, 16) -DEFINE_LOAD_PARTIAL_BITS(u, 32) -DEFINE_LOAD_PARTIAL_BITS(u, 64) - -int64 load_partial_bits(i8 *ptr, u32 offset, u32 bits, u32 n, u32 is_signed) { -#define CALL_LOAD_PARTIAL_BITS_FUNC(N) \ - else if (n == N) { \ - if (is_signed) \ - return load_partial_bits_i##N(ptr, offset, bits); \ - else \ - return load_partial_bits_u##N(ptr, offset, bits); \ - } - if (false) { - } - CALL_LOAD_PARTIAL_BITS_FUNC(8) - CALL_LOAD_PARTIAL_BITS_FUNC(16) - CALL_LOAD_PARTIAL_BITS_FUNC(32) - CALL_LOAD_PARTIAL_BITS_FUNC(64) - else { - assert(false); - } - return 0; -} } #endif diff --git a/taichi/struct/struct_llvm.cpp b/taichi/struct/struct_llvm.cpp index 2f760697a5d06..a32ca8d64ce00 100644 --- a/taichi/struct/struct_llvm.cpp +++ b/taichi/struct/struct_llvm.cpp @@ -68,6 +68,7 @@ void StructCompilerLLVM::generate_types(SNode &snode) { ch_types.push_back(ch->dt.get_ptr()); ch_offsets.push_back(total_offset); total_offset += ch->dt->as()->get_num_bits(); + ch->dt->as()->set_physical_type(snode.physical_type); } snode.dt = TypeFactory::get_instance().get_bit_struct_type( From f045ec172c6648c5af2933c6cd2d4b94cacc34ca Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Thu, 19 Nov 2020 09:59:01 +0800 Subject: [PATCH 19/32] set physical type for bit_array --- taichi/struct/struct_llvm.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/taichi/struct/struct_llvm.cpp b/taichi/struct/struct_llvm.cpp index a32ca8d64ce00..89ed21b360d86 100644 --- a/taichi/struct/struct_llvm.cpp +++ b/taichi/struct/struct_llvm.cpp @@ -81,6 +81,7 @@ void StructCompilerLLVM::generate_types(SNode &snode) { TI_ASSERT(snode.ch.size() == 1); auto &ch = snode.ch[0]; Type *ch_type = ch->dt.get_ptr(); + ch->dt->as()->set_physical_type(snode.physical_type); snode.dt = TypeFactory::get_instance().get_bit_array_type( snode.physical_type, ch_type, snode.n); From e344ca03ef90db005c6bcb0dd865e38e33fd0074 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Wed, 18 Nov 2020 21:52:09 -0500 Subject: [PATCH 20/32] [skip ci] enforce code format --- taichi/codegen/codegen_llvm.cpp | 23 +++++++++++++---------- taichi/ir/type.cpp | 20 +++++++++++++------- taichi/ir/type.h | 8 +++++--- taichi/runtime/llvm/runtime.cpp | 1 - 4 files changed, 31 insertions(+), 21 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index aa9772de51507..28fe0900de16f 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1110,13 +1110,15 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { auto cit = ptr_type->get_pointee_type()->as(); llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr; read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset); - auto func_name = fmt::format("set_partial_bits_b{}", data_type_bits(cit->get_physical_type())); + auto func_name = fmt::format("set_partial_bits_b{}", + data_type_bits(cit->get_physical_type())); builder->CreateCall( get_runtime_function(func_name), {builder->CreateBitCast(byte_ptr, llvm_ptr_type(cit->get_physical_type())), bit_offset, tlctx->get_constant(cit->get_num_bits()), - builder->CreateIntCast(llvm_val[stmt->data], llvm_type(cit->get_physical_type()), false)}); + builder->CreateIntCast(llvm_val[stmt->data], + llvm_type(cit->get_physical_type()), false)}); } else { builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]); } @@ -1138,11 +1140,14 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { // then right shift `physical_type - num_bits` auto bit_end = builder->CreateAdd(bit_offset, tlctx->get_constant(cit->get_num_bits())); - auto left = builder->CreateSub(tlctx->get_constant(data_type_bits(cit->get_physical_type())), bit_end); - auto right = builder->CreateSub(tlctx->get_constant(data_type_bits(cit->get_physical_type())), - tlctx->get_constant(cit->get_num_bits())); + auto left = builder->CreateSub( + tlctx->get_constant(data_type_bits(cit->get_physical_type())), bit_end); + auto right = builder->CreateSub( + tlctx->get_constant(data_type_bits(cit->get_physical_type())), + tlctx->get_constant(cit->get_num_bits())); left = builder->CreateIntCast(left, bit_level_container->getType(), false); - right = builder->CreateIntCast(right, bit_level_container->getType(), false); + right = + builder->CreateIntCast(right, bit_level_container->getType(), false); auto step1 = builder->CreateShl(bit_level_container, left); llvm::Value *step2 = nullptr; if (cit->get_is_signed()) @@ -1296,8 +1301,7 @@ void CodeGenLLVM::visit(SNodeLookupStmt *stmt) { snode->dt.get_ptr()->as()->get_element_num_bits(); auto offset = tlctx->get_constant(element_num_bits); offset = builder->CreateMul(offset, llvm_val[stmt->input_index]); - llvm_val[stmt] = create_bit_ptr_struct( - llvm_val[stmt->input_snode], offset); + llvm_val[stmt] = create_bit_ptr_struct(llvm_val[stmt->input_snode], offset); } else { TI_INFO(snode_type_name(snode->type)); TI_NOT_IMPLEMENTED @@ -1312,8 +1316,7 @@ void CodeGenLLVM::visit(GetChStmt *stmt) { auto bit_offset = bit_struct->get_member_bit_offset( stmt->input_snode->child_id(stmt->output_snode)); auto offset = tlctx->get_constant(bit_offset); - llvm_val[stmt] = - create_bit_ptr_struct(llvm_val[stmt->input_ptr], offset); + llvm_val[stmt] = create_bit_ptr_struct(llvm_val[stmt->input_ptr], offset); } else { auto ch = create_call(stmt->output_snode->get_ch_from_parent_func_name(), {builder->CreateBitCast( diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 02ee106e4e33f..6a01573d47d65 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -102,8 +102,10 @@ std::string CustomIntType::to_string() const { } CustomIntType::CustomIntType(int num_bits, bool is_signed) - : compute_type(nullptr), physical_type(nullptr), - num_bits_(num_bits), is_signed_(is_signed) { + : compute_type(nullptr), + physical_type(nullptr), + num_bits_(num_bits), + is_signed_(is_signed) { // TODO(type): support customizable compute_type // and should we expose it to users? TI_ASSERT(num_bits <= 32); @@ -127,8 +129,10 @@ CustomIntType::CustomIntType(int num_bits, bool is_signed) CustomIntType::CustomIntType(int compute_type_bits, int num_bits, bool is_signed) - : compute_type(nullptr), physical_type(nullptr), - num_bits_(num_bits), is_signed_(is_signed) { + : compute_type(nullptr), + physical_type(nullptr), + num_bits_(num_bits), + is_signed_(is_signed) { auto type_id = PrimitiveTypeID::unknown; if (false) { } @@ -144,8 +148,10 @@ CustomIntType::CustomIntType(int compute_type_bits, Type *physical_type, int num_bits, bool is_signed) - : compute_type(nullptr), physical_type(physical_type), - num_bits_(num_bits), is_signed_(is_signed) { + : compute_type(nullptr), + physical_type(physical_type), + num_bits_(num_bits), + is_signed_(is_signed) { auto type_id = PrimitiveTypeID::unknown; if (false) { } @@ -154,7 +160,7 @@ CustomIntType::CustomIntType(int compute_type_bits, SET_COMPUTE_TYPE(compute_type_bits, 16) SET_COMPUTE_TYPE(compute_type_bits, 8) else {TI_NOT_IMPLEMENTED} compute_type = - TypeFactory::get_instance().get_primitive_type(type_id); + TypeFactory::get_instance().get_primitive_type(type_id); } BitStructType::BitStructType(PrimitiveType *physical_type, diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 03dd4b7308730..b6f03b79ef5b2 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -170,7 +170,9 @@ class CustomIntType : public Type { CustomIntType(int compute_type_bits, int numBits, bool isSigned); - CustomIntType(int compute_type_bits, Type *physical_type, int num_bits, + CustomIntType(int compute_type_bits, + Type *physical_type, + int num_bits, bool is_signed); ~CustomIntType() override { @@ -179,11 +181,11 @@ class CustomIntType : public Type { std::string to_string() const override; - void set_physical_type(Type*physical_type_) { + void set_physical_type(Type *physical_type_) { this->physical_type = physical_type_; } - Type* get_physical_type() { + Type *get_physical_type() { return physical_type; } diff --git a/taichi/runtime/llvm/runtime.cpp b/taichi/runtime/llvm/runtime.cpp index 4ab593c1303a8..ba52ebc67826c 100644 --- a/taichi/runtime/llvm/runtime.cpp +++ b/taichi/runtime/llvm/runtime.cpp @@ -1571,7 +1571,6 @@ DEFINE_SET_PARTIAL_BITS(8); DEFINE_SET_PARTIAL_BITS(16); DEFINE_SET_PARTIAL_BITS(32); DEFINE_SET_PARTIAL_BITS(64); - } #endif From a20a0e86dfd9366e78087f8c580a02ad87ae75ff Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Thu, 19 Nov 2020 15:06:28 +0800 Subject: [PATCH 21/32] modify test cases --- tests/python/test_bit_struct.py | 147 ++++++++------------------------ 1 file changed, 37 insertions(+), 110 deletions(-) diff --git a/tests/python/test_bit_struct.py b/tests/python/test_bit_struct.py index 5cc45671e5f09..d29f426ceadfb 100644 --- a/tests/python/test_bit_struct.py +++ b/tests/python/test_bit_struct.py @@ -77,113 +77,40 @@ def verify_val(idx: ti.i32): verify_val.__wrapped__(idx) -@ti.test(arch=ti.cpu, debug=True, cfg_optimization=False) -def test_bit_struct_with_physical_type(): - ci8_5 = ti.type_factory_._get_custom_int_type(8, 5, True) - ci8_3 = ti.type_factory_._get_custom_int_type(8, 3, False) - ci16_4 = ti.type_factory_._get_custom_int_type(16, 4, True) - cu16_12 = ti.type_factory_._get_custom_int_type(16, 12, False) - - ci32_17 = ti.type_factory_._get_custom_int_type(32, 17, True) - ci32_11 = ti.type_factory_._get_custom_int_type(32, 11, True) - cu32_4 = ti.type_factory_._get_custom_int_type(32, 4, False) - - ci64_33 = ti.type_factory_._get_custom_int_type(64, 32, True) - ci64_20 = ti.type_factory_._get_custom_int_type(64, 21, False) - ci64_7 = ti.type_factory_._get_custom_int_type(64, 7, False) - - a = ti.field(dtype=ci8_5) - b = ti.field(dtype=ci8_3) - - c = ti.field(dtype=ci16_4) - d = ti.field(dtype=cu16_12) - - e = ti.field(dtype=ci32_17) - f = ti.field(dtype=ci32_11) - g = ti.field(dtype=cu32_4) - - h = ti.field(dtype=ci64_33) - i = ti.field(dtype=ci64_20) - j = ti.field(dtype=ci64_7) - - k = ti.field(dtype=ci16_4) - l = ti.field(dtype=cu16_12) - - m = ti.field(dtype=ci32_17) - n = ti.field(dtype=ci32_11) - o = ti.field(dtype=cu32_4) - - ti.root._bit_struct(num_bits=8).place(a, b) - ti.root._bit_struct(num_bits=16).place(c, d) - ti.root._bit_struct(num_bits=32).place(e, f, g) - ti.root._bit_struct(num_bits=64).place(h, i, j) - - ti.root._bit_struct(num_bits=32).place(k, l) - ti.root._bit_struct(num_bits=64).place(m, n, o) - - test_case_np = np.array( - [[ - 2**4 - 1, 2**3 - 1, 2**3 - 1, 2**12 - 1, 2**16 - 1, 2**10 - 1, - 2**4 - 1, 2**31 - 1, 2**21 - 1, 2**7 - 1, 2**3 - 1, 2**12 - 1, - 2**16 - 1, 2**10 - 1, 2**4 - 1 - ], - [ - -2**3, - 2**2 - 1, - -2**2, - 2**11 - 1, - -2**15, - -2**9, - 2**2 - 1, - -2**30, - 2**20 - 1, - 2**6 - 1, - -2**2, - 2**11 - 1, - -2**15, - -2**9, - 2**2 - 1, - ], [3, 4, 5, 16, 21, 34, 1, 2020, 456, 123, 5, 16, 21, 34, 1]], - dtype=np.int32) - test_case = ti.Vector.field(15, dtype=ti.i32, shape=len(test_case_np)) - test_case.from_numpy(test_case_np) - - @ti.kernel - def set_val(idx: ti.i32): - a[None] = test_case[idx][0] - b[None] = test_case[idx][1] - c[None] = test_case[idx][2] - d[None] = test_case[idx][3] - e[None] = test_case[idx][4] - f[None] = test_case[idx][5] - g[None] = test_case[idx][6] - h[None] = test_case[idx][7] - i[None] = test_case[idx][8] - j[None] = test_case[idx][9] - k[None] = test_case[idx][10] - l[None] = test_case[idx][11] - m[None] = test_case[idx][12] - n[None] = test_case[idx][13] - o[None] = test_case[idx][14] - - @ti.kernel - def verify_val(idx: ti.i32): - assert a[None] == test_case[idx][0] - assert b[None] == test_case[idx][1] - assert c[None] == test_case[idx][2] - assert d[None] == test_case[idx][3] - assert e[None] == test_case[idx][4] - assert f[None] == test_case[idx][5] - assert g[None] == test_case[idx][6] - assert h[None] == test_case[idx][7] - assert i[None] == test_case[idx][8] - assert j[None] == test_case[idx][9] - assert k[None] == test_case[idx][10] - assert l[None] == test_case[idx][11] - assert m[None] == test_case[idx][12] - assert n[None] == test_case[idx][13] - assert o[None] == test_case[idx][14] - - for idx in range(len(test_case_np)): - set_val(idx) - verify_val(idx) +def test_bit_struct(): + def test_single_bit_struct(physical_type, compute_type, custom_bits, test_case): + ti.init(arch=ti.cpu, debug=True, print_ir=False, cfg_optimization=False) + + cit1 = ti.type_factory_._get_custom_int_type(compute_type, custom_bits[0], True) + cit2 = ti.type_factory_._get_custom_int_type(compute_type, custom_bits[1], False) + cit3 = ti.type_factory_._get_custom_int_type(compute_type, custom_bits[2], True) + + a = ti.field(dtype=cit1) + b = ti.field(dtype=cit2) + c = ti.field(dtype=cit3) + ti.root._bit_struct(num_bits=physical_type).place(a, b, c) + + @ti.kernel + def set_val(test_val: ti.ext_arr()): + a[None] = test_val[0] + b[None] = test_val[1] + c[None] = test_val[2] + + @ti.kernel + def verify_val(test_val: ti.ext_arr()): + assert a[None] == test_val[0] + assert b[None] == test_val[1] + assert c[None] == test_val[2] + + set_val(test_case) + verify_val(test_case) + + test_single_bit_struct(8, 8, [3, 3, 2], np.array([2**2-1, 2**3-1, -2**1])) + test_single_bit_struct(16, 16, [4, 7, 5], np.array([2**3-1, 2**7-1, -2**4])) + test_single_bit_struct(32, 32, [17, 11, 4], np.array([2**16-1, 2**10-1, -2**3])) + test_single_bit_struct(64, 64, [32, 23, 9], np.array([2**31-1, 2**23-1, -2**8])) + test_single_bit_struct(32, 16, [7, 12, 13], np.array([2**6-1, 2**12-1, -2**12])) + test_single_bit_struct(64, 32, [18, 22, 24], np.array([2**17-1, 2**22-1, -2**23])) + + test_single_bit_struct(16, 16, [5, 5, 6], np.array([15, 5, 20])) + test_single_bit_struct(32, 32, [10, 10, 12], np.array([11, 19, 2020])) From 18a67381b523920fb954262c5864f3c091b9b35c Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Thu, 19 Nov 2020 02:07:42 -0500 Subject: [PATCH 22/32] [skip ci] enforce code format --- tests/python/test_bit_struct.py | 37 ++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/tests/python/test_bit_struct.py b/tests/python/test_bit_struct.py index d29f426ceadfb..8fe2bed9c6a97 100644 --- a/tests/python/test_bit_struct.py +++ b/tests/python/test_bit_struct.py @@ -78,12 +78,19 @@ def verify_val(idx: ti.i32): def test_bit_struct(): - def test_single_bit_struct(physical_type, compute_type, custom_bits, test_case): - ti.init(arch=ti.cpu, debug=True, print_ir=False, cfg_optimization=False) - - cit1 = ti.type_factory_._get_custom_int_type(compute_type, custom_bits[0], True) - cit2 = ti.type_factory_._get_custom_int_type(compute_type, custom_bits[1], False) - cit3 = ti.type_factory_._get_custom_int_type(compute_type, custom_bits[2], True) + def test_single_bit_struct(physical_type, compute_type, custom_bits, + test_case): + ti.init(arch=ti.cpu, + debug=True, + print_ir=False, + cfg_optimization=False) + + cit1 = ti.type_factory_._get_custom_int_type(compute_type, + custom_bits[0], True) + cit2 = ti.type_factory_._get_custom_int_type(compute_type, + custom_bits[1], False) + cit3 = ti.type_factory_._get_custom_int_type(compute_type, + custom_bits[2], True) a = ti.field(dtype=cit1) b = ti.field(dtype=cit2) @@ -105,12 +112,18 @@ def verify_val(test_val: ti.ext_arr()): set_val(test_case) verify_val(test_case) - test_single_bit_struct(8, 8, [3, 3, 2], np.array([2**2-1, 2**3-1, -2**1])) - test_single_bit_struct(16, 16, [4, 7, 5], np.array([2**3-1, 2**7-1, -2**4])) - test_single_bit_struct(32, 32, [17, 11, 4], np.array([2**16-1, 2**10-1, -2**3])) - test_single_bit_struct(64, 64, [32, 23, 9], np.array([2**31-1, 2**23-1, -2**8])) - test_single_bit_struct(32, 16, [7, 12, 13], np.array([2**6-1, 2**12-1, -2**12])) - test_single_bit_struct(64, 32, [18, 22, 24], np.array([2**17-1, 2**22-1, -2**23])) + test_single_bit_struct(8, 8, [3, 3, 2], + np.array([2**2 - 1, 2**3 - 1, -2**1])) + test_single_bit_struct(16, 16, [4, 7, 5], + np.array([2**3 - 1, 2**7 - 1, -2**4])) + test_single_bit_struct(32, 32, [17, 11, 4], + np.array([2**16 - 1, 2**10 - 1, -2**3])) + test_single_bit_struct(64, 64, [32, 23, 9], + np.array([2**31 - 1, 2**23 - 1, -2**8])) + test_single_bit_struct(32, 16, [7, 12, 13], + np.array([2**6 - 1, 2**12 - 1, -2**12])) + test_single_bit_struct(64, 32, [18, 22, 24], + np.array([2**17 - 1, 2**22 - 1, -2**23])) test_single_bit_struct(16, 16, [5, 5, 6], np.array([15, 5, 20])) test_single_bit_struct(32, 32, [10, 10, 12], np.array([11, 19, 2020])) From b8c4d493c10e8da912ba74d785fe703cdb2dfa93 Mon Sep 17 00:00:00 2001 From: Jiafeng Liu Date: Thu, 19 Nov 2020 15:17:04 +0800 Subject: [PATCH 23/32] Apply suggestions from code review Co-authored-by: Yuanming Hu --- taichi/ir/type.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 6a01573d47d65..5398e379dd539 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -107,7 +107,7 @@ CustomIntType::CustomIntType(int num_bits, bool is_signed) num_bits_(num_bits), is_signed_(is_signed) { // TODO(type): support customizable compute_type - // and should we expose it to users? + // and expose it to users in the future. TI_ASSERT(num_bits <= 32); if (is_signed) { compute_type = From a7acc5783ad070aa743bb530d0686cd5aa2ac775 Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Thu, 19 Nov 2020 21:06:15 +0800 Subject: [PATCH 24/32] fix type factory --- taichi/ir/type_factory.cpp | 2 +- taichi/ir/type_factory.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index f8d2eb85dfe6c..00622cd840eac 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -49,7 +49,7 @@ Type *TypeFactory::get_custom_int_type(int num_bits, bool is_signed) { Type *TypeFactory::_get_custom_int_type(int compute_type_bits, int num_bits, bool is_signed) { - auto key = std::make_pair(num_bits, is_signed); + auto key = std::make_tuple(compute_type_bits, num_bits, is_signed); if (custom_int_types_with_compute_types_.find(key) == custom_int_types_with_compute_types_.end()) { custom_int_types_with_compute_types_[key] = diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index e39838c7c030f..bd8032b1144a4 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -53,7 +53,7 @@ class TypeFactory { // TODO: use unordered map std::map, std::unique_ptr> custom_int_types_; - std::map, std::unique_ptr> + std::map, std::unique_ptr> custom_int_types_with_compute_types_; // TODO: avoid duplication From 704bbe72ec6ac700756e22a2e145e87d7d14c233 Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Thu, 19 Nov 2020 21:21:01 +0800 Subject: [PATCH 25/32] change funtion name --- taichi/ir/type_factory.cpp | 2 +- taichi/ir/type_factory.h | 2 +- taichi/python/export_lang.cpp | 3 ++- tests/python/test_bit_struct.py | 6 +++--- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index 00622cd840eac..3d718e3c72a94 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -46,7 +46,7 @@ Type *TypeFactory::get_custom_int_type(int num_bits, bool is_signed) { return custom_int_types_[key].get(); } -Type *TypeFactory::_get_custom_int_type(int compute_type_bits, +Type *TypeFactory::get_custom_int_type_with_compute_type(int compute_type_bits, int num_bits, bool is_signed) { auto key = std::make_tuple(compute_type_bits, num_bits, is_signed); diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index bd8032b1144a4..ba126f6e283c0 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -23,7 +23,7 @@ class TypeFactory { Type *get_custom_int_type(int num_bits, bool is_signed); - Type *_get_custom_int_type(int compute_type_bits, + Type *get_custom_int_type_with_compute_type(int compute_type_bits, int num_bits, bool is_signed); diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 8721940ab1ae4..6607494ad239c 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -729,7 +729,8 @@ void export_lang(py::module &m) { py::class_(m, "TypeFactory") .def("get_custom_int_type", &TypeFactory::get_custom_int_type, py::return_value_policy::reference) - .def("_get_custom_int_type", &TypeFactory::_get_custom_int_type, + .def("get_custom_int_type_with_compute_type", + &TypeFactory::get_custom_int_type_with_compute_type, py::return_value_policy::reference); m.def("get_type_factory_instance", TypeFactory::get_instance, diff --git a/tests/python/test_bit_struct.py b/tests/python/test_bit_struct.py index 8fe2bed9c6a97..638d9d2330bf1 100644 --- a/tests/python/test_bit_struct.py +++ b/tests/python/test_bit_struct.py @@ -85,11 +85,11 @@ def test_single_bit_struct(physical_type, compute_type, custom_bits, print_ir=False, cfg_optimization=False) - cit1 = ti.type_factory_._get_custom_int_type(compute_type, + cit1 = ti.type_factory_.get_custom_int_type_with_compute_type(compute_type, custom_bits[0], True) - cit2 = ti.type_factory_._get_custom_int_type(compute_type, + cit2 = ti.type_factory_.get_custom_int_type_with_compute_type(compute_type, custom_bits[1], False) - cit3 = ti.type_factory_._get_custom_int_type(compute_type, + cit3 = ti.type_factory_.get_custom_int_type_with_compute_type(compute_type, custom_bits[2], True) a = ti.field(dtype=cit1) From 5022a10514754388ccc8fee94840f9b4834d5340 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Thu, 19 Nov 2020 08:21:39 -0500 Subject: [PATCH 26/32] [skip ci] enforce code format --- taichi/ir/type_factory.cpp | 4 ++-- taichi/ir/type_factory.h | 4 ++-- tests/python/test_bit_struct.py | 12 ++++++------ 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index 3d718e3c72a94..f13ff58e44420 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -47,8 +47,8 @@ Type *TypeFactory::get_custom_int_type(int num_bits, bool is_signed) { } Type *TypeFactory::get_custom_int_type_with_compute_type(int compute_type_bits, - int num_bits, - bool is_signed) { + int num_bits, + bool is_signed) { auto key = std::make_tuple(compute_type_bits, num_bits, is_signed); if (custom_int_types_with_compute_types_.find(key) == custom_int_types_with_compute_types_.end()) { diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index ba126f6e283c0..8f53443242914 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -24,8 +24,8 @@ class TypeFactory { Type *get_custom_int_type(int num_bits, bool is_signed); Type *get_custom_int_type_with_compute_type(int compute_type_bits, - int num_bits, - bool is_signed); + int num_bits, + bool is_signed); Type *get_bit_struct_type(PrimitiveType *physical_type, std::vector member_types, diff --git a/tests/python/test_bit_struct.py b/tests/python/test_bit_struct.py index 638d9d2330bf1..abf239a7dda42 100644 --- a/tests/python/test_bit_struct.py +++ b/tests/python/test_bit_struct.py @@ -85,12 +85,12 @@ def test_single_bit_struct(physical_type, compute_type, custom_bits, print_ir=False, cfg_optimization=False) - cit1 = ti.type_factory_.get_custom_int_type_with_compute_type(compute_type, - custom_bits[0], True) - cit2 = ti.type_factory_.get_custom_int_type_with_compute_type(compute_type, - custom_bits[1], False) - cit3 = ti.type_factory_.get_custom_int_type_with_compute_type(compute_type, - custom_bits[2], True) + cit1 = ti.type_factory_.get_custom_int_type_with_compute_type( + compute_type, custom_bits[0], True) + cit2 = ti.type_factory_.get_custom_int_type_with_compute_type( + compute_type, custom_bits[1], False) + cit3 = ti.type_factory_.get_custom_int_type_with_compute_type( + compute_type, custom_bits[2], True) a = ti.field(dtype=cit1) b = ti.field(dtype=cit2) From c1bf0e6699cb6bedebb15e91552768de858777bf Mon Sep 17 00:00:00 2001 From: Jiafeng Liu Date: Fri, 20 Nov 2020 09:52:33 +0800 Subject: [PATCH 27/32] Apply suggestions from code review Co-authored-by: Yuanming Hu --- taichi/ir/type.cpp | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 5398e379dd539..3158fc25c8068 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -129,20 +129,7 @@ CustomIntType::CustomIntType(int num_bits, bool is_signed) CustomIntType::CustomIntType(int compute_type_bits, int num_bits, bool is_signed) - : compute_type(nullptr), - physical_type(nullptr), - num_bits_(num_bits), - is_signed_(is_signed) { - auto type_id = PrimitiveTypeID::unknown; - if (false) { - } - SET_COMPUTE_TYPE(compute_type_bits, 64) - SET_COMPUTE_TYPE(compute_type_bits, 32) - SET_COMPUTE_TYPE(compute_type_bits, 16) - SET_COMPUTE_TYPE(compute_type_bits, 8) - else {TI_NOT_IMPLEMENTED} compute_type = - TypeFactory::get_instance().get_primitive_type(type_id); -} + : CustomIntType(compute_type_bits, nullptr, num_bits, is_signed) {} CustomIntType::CustomIntType(int compute_type_bits, Type *physical_type, @@ -162,6 +149,7 @@ CustomIntType::CustomIntType(int compute_type_bits, else {TI_NOT_IMPLEMENTED} compute_type = TypeFactory::get_instance().get_primitive_type(type_id); } +#undef SET_COMPUTE_TYPE BitStructType::BitStructType(PrimitiveType *physical_type, std::vector member_types, From cc8878b2f0d14c91302ec2d01befcebe2711a27d Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Thu, 19 Nov 2020 20:52:53 -0500 Subject: [PATCH 28/32] [skip ci] enforce code format --- taichi/ir/type.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 3158fc25c8068..fe2c1f93946c4 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -129,7 +129,8 @@ CustomIntType::CustomIntType(int num_bits, bool is_signed) CustomIntType::CustomIntType(int compute_type_bits, int num_bits, bool is_signed) - : CustomIntType(compute_type_bits, nullptr, num_bits, is_signed) {} + : CustomIntType(compute_type_bits, nullptr, num_bits, is_signed) { +} CustomIntType::CustomIntType(int compute_type_bits, Type *physical_type, From a2d41969dc6965d7fd2892bd7aa105c5c68dbe4b Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Fri, 20 Nov 2020 11:15:28 +0800 Subject: [PATCH 29/32] modify APIs --- taichi/ir/type.cpp | 55 ++++++--------------------------- taichi/ir/type.h | 10 ++---- taichi/ir/type_factory.cpp | 20 +++++------- taichi/ir/type_factory.h | 8 ++--- taichi/python/export_lang.cpp | 4 +-- tests/python/test_bit_struct.py | 12 +++---- 6 files changed, 28 insertions(+), 81 deletions(-) diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index fe2c1f93946c4..be0bf2859aa48 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -101,56 +101,19 @@ std::string CustomIntType::to_string() const { return fmt::format("c{}{}", is_signed_ ? 'i' : 'u', num_bits_); } -CustomIntType::CustomIntType(int num_bits, bool is_signed) - : compute_type(nullptr), - physical_type(nullptr), - num_bits_(num_bits), - is_signed_(is_signed) { - // TODO(type): support customizable compute_type - // and expose it to users in the future. - TI_ASSERT(num_bits <= 32); - if (is_signed) { - compute_type = - TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i32); - } else { - compute_type = - TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::u32); - } -} - -#define SET_COMPUTE_TYPE(n, N) \ - else if (n == N) { \ - if (is_signed) \ - type_id = PrimitiveTypeID::i##N; \ - else \ - type_id = PrimitiveTypeID::u##N; \ - } - -CustomIntType::CustomIntType(int compute_type_bits, - int num_bits, - bool is_signed) - : CustomIntType(compute_type_bits, nullptr, num_bits, is_signed) { -} - -CustomIntType::CustomIntType(int compute_type_bits, - Type *physical_type, - int num_bits, - bool is_signed) - : compute_type(nullptr), +CustomIntType::CustomIntType(int num_bits, + bool is_signed, + Type* compute_type, + Type *physical_type) + : compute_type(compute_type), physical_type(physical_type), num_bits_(num_bits), is_signed_(is_signed) { - auto type_id = PrimitiveTypeID::unknown; - if (false) { + if (compute_type == nullptr) { + auto type_id = is_signed ? PrimitiveTypeID::i32 : PrimitiveTypeID::u32; + this->compute_type = TypeFactory::get_instance().get_primitive_type(type_id); } - SET_COMPUTE_TYPE(compute_type_bits, 64) - SET_COMPUTE_TYPE(compute_type_bits, 32) - SET_COMPUTE_TYPE(compute_type_bits, 16) - SET_COMPUTE_TYPE(compute_type_bits, 8) - else {TI_NOT_IMPLEMENTED} compute_type = - TypeFactory::get_instance().get_primitive_type(type_id); -} -#undef SET_COMPUTE_TYPE +} BitStructType::BitStructType(PrimitiveType *physical_type, std::vector member_types, diff --git a/taichi/ir/type.h b/taichi/ir/type.h index b6f03b79ef5b2..b3a19bae25583 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -166,14 +166,8 @@ class VectorType : public Type { class CustomIntType : public Type { public: - CustomIntType(int num_bits, bool is_signed); - - CustomIntType(int compute_type_bits, int numBits, bool isSigned); - - CustomIntType(int compute_type_bits, - Type *physical_type, - int num_bits, - bool is_signed); + CustomIntType(int num_bits, bool is_signed, + Type* compute_type=nullptr, Type* physical_type=nullptr); ~CustomIntType() override { delete compute_type; diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index f13ff58e44420..0a28be1561981 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -37,27 +37,21 @@ Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) { return pointer_types_[key].get(); } -Type *TypeFactory::get_custom_int_type(int num_bits, bool is_signed) { - auto key = std::make_pair(num_bits, is_signed); - if (custom_int_types_.find(key) == custom_int_types_.end()) { - custom_int_types_[key] = - std::make_unique(num_bits, is_signed); - } - return custom_int_types_[key].get(); -} - -Type *TypeFactory::get_custom_int_type_with_compute_type(int compute_type_bits, - int num_bits, - bool is_signed) { +Type *TypeFactory::get_custom_int_type(int num_bits, + bool is_signed, + int compute_type_bits) { auto key = std::make_tuple(compute_type_bits, num_bits, is_signed); if (custom_int_types_with_compute_types_.find(key) == custom_int_types_with_compute_types_.end()) { custom_int_types_with_compute_types_[key] = - std::make_unique(compute_type_bits, num_bits, is_signed); + std::make_unique(num_bits, is_signed, + get_primitive_int_type(compute_type_bits, is_signed)); } return custom_int_types_with_compute_types_[key].get(); } +#undef SET_COMPUTE_TYPE + Type *TypeFactory::get_bit_struct_type(PrimitiveType *physical_type, std::vector member_types, std::vector member_bit_offsets) { diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index 8f53443242914..a11d27df68178 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -21,11 +21,9 @@ class TypeFactory { Type *get_pointer_type(Type *element, bool is_bit_pointer = false); - Type *get_custom_int_type(int num_bits, bool is_signed); - - Type *get_custom_int_type_with_compute_type(int compute_type_bits, - int num_bits, - bool is_signed); + Type *get_custom_int_type(int num_bits, + bool is_signed, + int compute_type_bits=32); Type *get_bit_struct_type(PrimitiveType *physical_type, std::vector member_types, diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 6607494ad239c..e770a00d0e743 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -728,9 +728,7 @@ void export_lang(py::module &m) { // TypeFactory on Python-scope pointer destruction. py::class_(m, "TypeFactory") .def("get_custom_int_type", &TypeFactory::get_custom_int_type, - py::return_value_policy::reference) - .def("get_custom_int_type_with_compute_type", - &TypeFactory::get_custom_int_type_with_compute_type, + py::arg("num_bits"), py::arg("is_signed"),py::arg("compute_type_bits")=32, py::return_value_policy::reference); m.def("get_type_factory_instance", TypeFactory::get_instance, diff --git a/tests/python/test_bit_struct.py b/tests/python/test_bit_struct.py index abf239a7dda42..9cc85a5c64926 100644 --- a/tests/python/test_bit_struct.py +++ b/tests/python/test_bit_struct.py @@ -85,12 +85,12 @@ def test_single_bit_struct(physical_type, compute_type, custom_bits, print_ir=False, cfg_optimization=False) - cit1 = ti.type_factory_.get_custom_int_type_with_compute_type( - compute_type, custom_bits[0], True) - cit2 = ti.type_factory_.get_custom_int_type_with_compute_type( - compute_type, custom_bits[1], False) - cit3 = ti.type_factory_.get_custom_int_type_with_compute_type( - compute_type, custom_bits[2], True) + cit1 = ti.type_factory_.get_custom_int_type( + custom_bits[0], True, compute_type) + cit2 = ti.type_factory_.get_custom_int_type( + custom_bits[1], False, compute_type) + cit3 = ti.type_factory_.get_custom_int_type( + custom_bits[2], True, compute_type) a = ti.field(dtype=cit1) b = ti.field(dtype=cit2) From c9993cf81961d8db3c79727fb2adbe7fe4ca63c2 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Thu, 19 Nov 2020 22:16:03 -0500 Subject: [PATCH 30/32] [skip ci] enforce code format --- taichi/ir/type.cpp | 5 +++-- taichi/ir/type.h | 6 ++++-- taichi/ir/type_factory.cpp | 6 +++--- taichi/ir/type_factory.h | 2 +- taichi/python/export_lang.cpp | 3 ++- tests/python/test_bit_struct.py | 12 ++++++------ 6 files changed, 19 insertions(+), 15 deletions(-) diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index be0bf2859aa48..c12ae63544a67 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -103,7 +103,7 @@ std::string CustomIntType::to_string() const { CustomIntType::CustomIntType(int num_bits, bool is_signed, - Type* compute_type, + Type *compute_type, Type *physical_type) : compute_type(compute_type), physical_type(physical_type), @@ -111,7 +111,8 @@ CustomIntType::CustomIntType(int num_bits, is_signed_(is_signed) { if (compute_type == nullptr) { auto type_id = is_signed ? PrimitiveTypeID::i32 : PrimitiveTypeID::u32; - this->compute_type = TypeFactory::get_instance().get_primitive_type(type_id); + this->compute_type = + TypeFactory::get_instance().get_primitive_type(type_id); } } diff --git a/taichi/ir/type.h b/taichi/ir/type.h index b3a19bae25583..4ab63fbce311b 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -166,8 +166,10 @@ class VectorType : public Type { class CustomIntType : public Type { public: - CustomIntType(int num_bits, bool is_signed, - Type* compute_type=nullptr, Type* physical_type=nullptr); + CustomIntType(int num_bits, + bool is_signed, + Type *compute_type = nullptr, + Type *physical_type = nullptr); ~CustomIntType() override { delete compute_type; diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index 0a28be1561981..23b917937a441 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -43,9 +43,9 @@ Type *TypeFactory::get_custom_int_type(int num_bits, auto key = std::make_tuple(compute_type_bits, num_bits, is_signed); if (custom_int_types_with_compute_types_.find(key) == custom_int_types_with_compute_types_.end()) { - custom_int_types_with_compute_types_[key] = - std::make_unique(num_bits, is_signed, - get_primitive_int_type(compute_type_bits, is_signed)); + custom_int_types_with_compute_types_[key] = std::make_unique( + num_bits, is_signed, + get_primitive_int_type(compute_type_bits, is_signed)); } return custom_int_types_with_compute_types_[key].get(); } diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index a11d27df68178..c31e321395dc0 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -23,7 +23,7 @@ class TypeFactory { Type *get_custom_int_type(int num_bits, bool is_signed, - int compute_type_bits=32); + int compute_type_bits = 32); Type *get_bit_struct_type(PrimitiveType *physical_type, std::vector member_types, diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index e770a00d0e743..43d0b090e33b0 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -728,7 +728,8 @@ void export_lang(py::module &m) { // TypeFactory on Python-scope pointer destruction. py::class_(m, "TypeFactory") .def("get_custom_int_type", &TypeFactory::get_custom_int_type, - py::arg("num_bits"), py::arg("is_signed"),py::arg("compute_type_bits")=32, + py::arg("num_bits"), py::arg("is_signed"), + py::arg("compute_type_bits") = 32, py::return_value_policy::reference); m.def("get_type_factory_instance", TypeFactory::get_instance, diff --git a/tests/python/test_bit_struct.py b/tests/python/test_bit_struct.py index 9cc85a5c64926..8f0d1aeddb919 100644 --- a/tests/python/test_bit_struct.py +++ b/tests/python/test_bit_struct.py @@ -85,12 +85,12 @@ def test_single_bit_struct(physical_type, compute_type, custom_bits, print_ir=False, cfg_optimization=False) - cit1 = ti.type_factory_.get_custom_int_type( - custom_bits[0], True, compute_type) - cit2 = ti.type_factory_.get_custom_int_type( - custom_bits[1], False, compute_type) - cit3 = ti.type_factory_.get_custom_int_type( - custom_bits[2], True, compute_type) + cit1 = ti.type_factory_.get_custom_int_type(custom_bits[0], True, + compute_type) + cit2 = ti.type_factory_.get_custom_int_type(custom_bits[1], False, + compute_type) + cit3 = ti.type_factory_.get_custom_int_type(custom_bits[2], True, + compute_type) a = ti.field(dtype=cit1) b = ti.field(dtype=cit2) From 29f68ad0d0f0494f20fb62f8bae3bd2754dc1c1b Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Fri, 20 Nov 2020 11:57:14 +0800 Subject: [PATCH 31/32] update --- taichi/ir/type_factory.cpp | 7 +++---- taichi/ir/type_factory.h | 5 +---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index 23b917937a441..7605acb604955 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -41,13 +41,12 @@ Type *TypeFactory::get_custom_int_type(int num_bits, bool is_signed, int compute_type_bits) { auto key = std::make_tuple(compute_type_bits, num_bits, is_signed); - if (custom_int_types_with_compute_types_.find(key) == - custom_int_types_with_compute_types_.end()) { - custom_int_types_with_compute_types_[key] = std::make_unique( + if (custom_int_types.find(key) == custom_int_types.end()) { + custom_int_types[key] = std::make_unique( num_bits, is_signed, get_primitive_int_type(compute_type_bits, is_signed)); } - return custom_int_types_with_compute_types_[key].get(); + return custom_int_types[key].get(); } #undef SET_COMPUTE_TYPE diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index c31e321395dc0..44b08e248a106 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -49,10 +49,7 @@ class TypeFactory { std::map, std::unique_ptr> pointer_types_; // TODO: use unordered map - std::map, std::unique_ptr> custom_int_types_; - - std::map, std::unique_ptr> - custom_int_types_with_compute_types_; + std::map, std::unique_ptr> custom_int_types; // TODO: avoid duplication std::vector> bit_struct_types_; From 985f35ef4cd078660c0049fed6e17715761697ba Mon Sep 17 00:00:00 2001 From: liujiafeng Date: Fri, 20 Nov 2020 12:43:53 +0800 Subject: [PATCH 32/32] remove todo --- taichi/codegen/codegen_llvm.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 28fe0900de16f..2d88aee5004aa 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -327,7 +327,6 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) { TI_ASSERT(!to->is()); auto from_size = 0; if (from->is()) { - // TODO: replace 32 with a customizable type from_size = data_type_size(from->cast()->get_compute_type()); } else {