diff --git a/python/taichi/lang/__init__.py b/python/taichi/lang/__init__.py index 361bedc6c49d3..7b71f925a798d 100644 --- a/python/taichi/lang/__init__.py +++ b/python/taichi/lang/__init__.py @@ -3,6 +3,7 @@ from .matrix import Matrix, Vector from .transformer import TaichiSyntaxError from .ndrange import ndrange, GroupedNDRange +from .type_factory import TypeFactory from copy import deepcopy as _deepcopy import functools import os @@ -46,9 +47,12 @@ kernel_profiler_total_time = lambda: get_runtime( ).prog.kernel_profiler_total_time() -# Unstable API +# Legacy API type_factory_ = core.get_type_factory_instance() +# Unstable API +type_factory = TypeFactory() + def memory_profiler_print(): get_runtime().materialize() diff --git a/python/taichi/lang/snode.py b/python/taichi/lang/snode.py index d9a4272941037..586c7ed9998b9 100644 --- a/python/taichi/lang/snode.py +++ b/python/taichi/lang/snode.py @@ -42,7 +42,7 @@ def _bit_array(self, indices, dimensions, num_bits): dimensions = [dimensions] * len(indices) return SNode(self.ptr.bit_array(indices, dimensions, num_bits)) - def place(self, *args, offset=None): + def place(self, *args, offset=None, shared_exponent=False): from .expr import Expr from .util import is_taichi_class if offset is None: @@ -50,6 +50,9 @@ def place(self, *args, offset=None): if isinstance(offset, numbers.Number): offset = (offset, ) for arg in args: + assert shared_exponent == False + # TODO: implement shared exponent + if isinstance(arg, Expr): self.ptr.place(Expr(arg).ptr, offset) elif isinstance(arg, list): diff --git a/python/taichi/lang/type_factory.py b/python/taichi/lang/type_factory.py new file mode 100644 index 0000000000000..370de96fab913 --- /dev/null +++ b/python/taichi/lang/type_factory.py @@ -0,0 +1,20 @@ +class TypeFactory: + def __init__(self): + from taichi.core import ti_core + self.core = ti_core.get_type_factory_instance() + + def custom_int(self, bits, signed=True): + return self.core.get_custom_int_type(bits, signed) + + def 
custom_float(self, + significand_type, + exponent_type=None, + compute_type=None, + scale=1.0): + import taichi as ti + if compute_type is None: + compute_type = ti.get_runtime().default_fp.get_ptr() + return self.core.get_custom_float_type(significand_type, + exponent_type, + compute_type, + scale=scale) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 5f6a7b2116716..1ab9fe5a418b6 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1172,6 +1172,24 @@ void CodeGenLLVM::visit(GlobalPtrStmt *stmt) { TI_ERROR("Global Ptrs should have been lowered."); } +void CodeGenLLVM::store_custom_int(llvm::Value *bit_ptr, + CustomIntType *cit, + llvm::Value *value) { + llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr; + read_bit_pointer(bit_ptr, byte_ptr, bit_offset); + // TODO(type): CUDA only supports atomicCAS on 32- and 64-bit integers. + // Try to support CustomInt/FloatType with 8/16-bit physical + // types. + + create_call(fmt::format("set_partial_bits_b{}", + data_type_bits(cit->get_physical_type())), + {builder->CreateBitCast(byte_ptr, + llvm_ptr_type(cit->get_physical_type())), + bit_offset, tlctx->get_constant(cit->get_num_bits()), + builder->CreateIntCast( + value, llvm_type(cit->get_physical_type()), false)}); +} + void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { TI_ASSERT(!stmt->parent->mask() || stmt->width() == 1); TI_ASSERT(llvm_val[stmt->data]); @@ -1185,33 +1203,75 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { cit = cit_; store_value = llvm_val[stmt->data]; } else if (auto cft = pointee_type->cast()) { - cit = cft->get_digits_type()->as(); - store_value = float_to_custom_int(cft, cit, llvm_val[stmt->data]); + llvm::Value *digit_bits = nullptr; + auto digits_cit = cft->get_digits_type()->as(); + if (auto exp = cft->get_exponent_type()) { + // Extract exponent and digits from compute type (assumed to be f32 for + // now). 
+ TI_ASSERT(cft->get_compute_type()->is_primitive(PrimitiveTypeID::f32)); + + auto f32_bits = builder->CreateBitCast( + llvm_val[stmt->data], llvm::Type::getInt32Ty(*llvm_context)); + auto exponent_bits = builder->CreateAShr(f32_bits, 23); + exponent_bits = builder->CreateAnd(exponent_bits, + tlctx->get_constant((1 << 8) - 1)); + // f32 = 1 sign bit + 8 exponent bits + 23 fraction bits + auto value_bits = builder->CreateAShr( + f32_bits, tlctx->get_constant(23 - cft->get_digit_bits())); + + digit_bits = builder->CreateAnd( + value_bits, + tlctx->get_constant((1 << (cft->get_digit_bits())) - 1)); + + if (cft->get_is_signed()) { + // extract the sign bit + auto sign_bit = + builder->CreateAnd(f32_bits, tlctx->get_constant(0x80000000u)); + // insert the sign bit to digit bits + digit_bits = builder->CreateOr( + digit_bits, + builder->CreateLShr(sign_bit, 31 - cft->get_digit_bits())); + } + + auto exponent_cit = exp->as(); + + auto digits_snode = stmt->ptr->as()->output_snode; + auto exponent_snode = digits_snode->exp_snode; + + // Since we have fewer bits in the exponent type than in f32, an + // offset is necessary to make sure the stored exponent values are + // representable by the exponent custom int type. + exponent_bits = builder->CreateSub( + exponent_bits, + tlctx->get_constant(cft->get_exponent_conversion_offset())); + + // Compute the bit pointer of the exponent bits. 
+ TI_ASSERT(digits_snode->parent == exponent_snode->parent); + auto exponent_bit_ptr = + offset_bit_ptr(llvm_val[stmt->ptr], exponent_snode->bit_offset - + digits_snode->bit_offset); + store_custom_int(exponent_bit_ptr, exponent_cit, exponent_bits); + store_value = digit_bits; + } else { + digit_bits = llvm_val[stmt->data]; + store_value = float_to_custom_int(cft, digits_cit, digit_bits); + } + cit = digits_cit; } else { TI_NOT_IMPLEMENTED } - llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr; - read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset); - // TODO(type): CUDA only supports atomicCAS on 32- and 64-bit integers. - // Try to support CustomInt/FloatType with 8/16-bit physical - // types. - create_call(fmt::format("set_partial_bits_b{}", - data_type_bits(cit->get_physical_type())), - {builder->CreateBitCast( - byte_ptr, llvm_ptr_type(cit->get_physical_type())), - bit_offset, tlctx->get_constant(cit->get_num_bits()), - builder->CreateIntCast( - store_value, llvm_type(cit->get_physical_type()), false)}); + store_custom_int(llvm_val[stmt->ptr], cit, store_value); } else { builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]); } } -llvm::Value *CodeGenLLVM::load_as_custom_int(Stmt *ptr, Type *load_type) { +llvm::Value *CodeGenLLVM::load_as_custom_int(llvm::Value *ptr, + Type *load_type) { auto *cit = load_type->as(); // load bit pointer llvm::Value *byte_ptr, *bit_offset; - read_bit_pointer(llvm_val[ptr], byte_ptr, bit_offset); + read_bit_pointer(ptr, byte_ptr, bit_offset); auto bit_level_container = builder->CreateLoad(builder->CreateBitCast( byte_ptr, llvm_ptr_type(cit->get_physical_type()))); @@ -1264,6 +1324,55 @@ llvm::Value *CodeGenLLVM::reconstruct_custom_float(llvm::Value *digits, return builder->CreateFMul(cast, s); } +llvm::Value *CodeGenLLVM::load_custom_float_with_exponent( + llvm::Value *digits_bit_ptr, + llvm::Value *exponent_bit_ptr, + CustomFloatType *cft) { + // TODO: we ignore "scale" for CustomFloatType with exponent for 
now. Fix + // this. + TI_ASSERT(cft->get_scale() == 1); + auto digits = load_as_custom_int(digits_bit_ptr, cft->get_digits_type()); + + auto exponent_val = load_as_custom_int( + exponent_bit_ptr, cft->get_exponent_type()->as()); + + // Add back the exponent conversion offset to recover the biased f32 exponent + exponent_val = builder->CreateAdd( + exponent_val, tlctx->get_constant(cft->get_exponent_conversion_offset())); + + if (cft->get_compute_type()->is_primitive(PrimitiveTypeID::f32)) { + // Construct an f32 out of exponent_val and digits + // Assuming digits and exponent_val are i32 + // f32 = 1 sign bit + 8 exponent bits + 23 fraction bits + auto exponent_bits = + builder->CreateShl(exponent_val, tlctx->get_constant(23)); + + digits = builder->CreateAnd( + digits, + (1u << cft->get_digits_type()->as()->get_num_bits()) - + 1); + digits = builder->CreateShl( + digits, tlctx->get_constant(23 - cft->get_digit_bits())); + + auto fraction_bits = builder->CreateAnd(digits, (1u << 23) - 1); + + auto f32_bits = builder->CreateOr(exponent_bits, fraction_bits); + + if (cft->get_is_signed()) { + auto sign_bit = + builder->CreateAnd(digits, tlctx->get_constant(1u << (23))); + + sign_bit = builder->CreateShl(sign_bit, tlctx->get_constant(31 - (23))); + f32_bits = builder->CreateOr(f32_bits, sign_bit); + } + + return builder->CreateBitCast(f32_bits, + llvm::Type::getFloatTy(*llvm_context)); + } else { + TI_NOT_IMPLEMENTED; + } +} + void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { int width = stmt->width(); TI_ASSERT(width == 1); @@ -1271,10 +1380,26 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { if (ptr_type->is_bit_pointer()) { auto val_type = ptr_type->get_pointee_type(); if (val_type->is()) { - llvm_val[stmt] = load_as_custom_int(stmt->ptr, val_type); + llvm_val[stmt] = load_as_custom_int(llvm_val[stmt->ptr], val_type); } else if (auto cft = val_type->cast()) { - auto digits = load_as_custom_int(stmt->ptr, cft->get_digits_type()); - llvm_val[stmt] = 
reconstruct_custom_float(digits, val_type); + if (cft->get_exponent_type()) { + auto ptr = stmt->ptr->as(); + TI_ASSERT(ptr->width() == 1); + auto digits_bit_ptr = llvm_val[ptr]; + auto digits_snode = ptr->output_snode; + auto exponent_snode = digits_snode->exp_snode; + // Compute the bit pointer of the exponent bits. + TI_ASSERT(digits_snode->parent == exponent_snode->parent); + auto exponent_bit_ptr = + offset_bit_ptr(digits_bit_ptr, exponent_snode->bit_offset - + digits_snode->bit_offset); + llvm_val[stmt] = load_custom_float_with_exponent(digits_bit_ptr, + exponent_bit_ptr, cft); + } else { + auto digits = + load_as_custom_int(llvm_val[stmt->ptr], cft->get_digits_type()); + llvm_val[stmt] = reconstruct_custom_float(digits, val_type); + } } else { TI_NOT_IMPLEMENTED } @@ -1374,7 +1499,7 @@ void CodeGenLLVM::visit(IntegerOffsetStmt *stmt){TI_NOT_IMPLEMENTED} llvm::Value *CodeGenLLVM::create_bit_ptr_struct(llvm::Value *byte_ptr_base, llvm::Value *bit_offset) { - // 1. create a bit pointer struct + // 1. get the bit pointer LLVM struct // struct bit_pointer { // i8* byte_ptr; // i32 offset; @@ -1383,21 +1508,37 @@ llvm::Value *CodeGenLLVM::create_bit_ptr_struct(llvm::Value *byte_ptr_base, *llvm_context, {llvm::Type::getInt8PtrTy(*llvm_context), llvm::Type::getInt32Ty(*llvm_context), llvm::Type::getInt32Ty(*llvm_context)}); - // 2. alloca the bit pointer struct + // 2. allocate the bit pointer struct auto bit_ptr_struct = create_entry_block_alloca(struct_type); - // 3. store `input_ptr` into `bit_ptr_struct` - auto byte_ptr = builder->CreateBitCast( - byte_ptr_base, llvm::PointerType::getInt8PtrTy(*llvm_context)); - builder->CreateStore( - byte_ptr, builder->CreateGEP(bit_ptr_struct, {tlctx->get_constant(0), - tlctx->get_constant(0)})); - // 4. store `offset` in `bit_ptr_struct` - builder->CreateStore( - bit_offset, builder->CreateGEP(bit_ptr_struct, {tlctx->get_constant(0), - tlctx->get_constant(1)})); + // 3. 
store `byte_ptr_base` into `bit_ptr_struct` (if provided) + if (byte_ptr_base) { + auto byte_ptr = builder->CreateBitCast( + byte_ptr_base, llvm::PointerType::getInt8PtrTy(*llvm_context)); + builder->CreateStore( + byte_ptr, builder->CreateGEP(bit_ptr_struct, {tlctx->get_constant(0), + tlctx->get_constant(0)})); + } + // 4. store `offset` in `bit_ptr_struct` (if provided) + if (bit_offset) { + builder->CreateStore( + bit_offset, + builder->CreateGEP(bit_ptr_struct, + {tlctx->get_constant(0), tlctx->get_constant(1)})); + } return bit_ptr_struct; } +llvm::Value *CodeGenLLVM::offset_bit_ptr(llvm::Value *input_bit_ptr, + int bit_offset_delta) { + auto byte_ptr_base = builder->CreateLoad(builder->CreateGEP( + input_bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(0)})); + auto input_offset = builder->CreateLoad(builder->CreateGEP( + input_bit_ptr, {tlctx->get_constant(0), tlctx->get_constant(1)})); + auto new_bit_offset = + builder->CreateAdd(input_offset, tlctx->get_constant(bit_offset_delta)); + return create_bit_ptr_struct(byte_ptr_base, new_bit_offset); +} + void CodeGenLLVM::visit(SNodeLookupStmt *stmt) { llvm::Value *parent = nullptr; parent = llvm_val[stmt->input_snode]; diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index 8786fbad97aef..896282a5c352e 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -205,9 +205,13 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(GlobalPtrStmt *stmt) override; + void store_custom_int(llvm::Value *bit_ptr, + CustomIntType *cit, + llvm::Value *value); + void visit(GlobalStoreStmt *stmt) override; - llvm::Value *load_as_custom_int(Stmt *ptr, Type *load_type); + llvm::Value *load_as_custom_int(llvm::Value *ptr, Type *load_type); llvm::Value *extract_custom_int(llvm::Value *physical_value, llvm::Value *bit_offset, @@ -215,6 +219,10 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value 
*reconstruct_custom_float(llvm::Value *digits, Type *load_type); + llvm::Value *load_custom_float_with_exponent(llvm::Value *digits_bit_ptr, + llvm::Value *exponent_bit_ptr, + CustomFloatType *cft); + void visit(GlobalLoadStmt *stmt) override; void visit(ElementShuffleStmt *stmt) override; @@ -227,8 +235,10 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(IntegerOffsetStmt *stmt) override; - llvm::Value *create_bit_ptr_struct(llvm::Value *byte_ptr_base, - llvm::Value *bit_offset); + llvm::Value *create_bit_ptr_struct(llvm::Value *byte_ptr_base = nullptr, + llvm::Value *bit_offset = nullptr); + + llvm::Value *offset_bit_ptr(llvm::Value *input_bit_ptr, int bit_offset_delta); void visit(SNodeLookupStmt *stmt) override; diff --git a/taichi/ir/snode.cpp b/taichi/ir/snode.cpp index 73ed31c7c0465..399984def296b 100644 --- a/taichi/ir/snode.cpp +++ b/taichi/ir/snode.cpp @@ -42,6 +42,17 @@ void SNode::place(Expr &expr_, const std::vector &offset) { TI_ASSERT(expr_.is()); auto expr = expr_.cast(); TI_ERROR_IF(expr->snode != nullptr, "This variable has been placed."); + SNode *new_exp_snode = nullptr; + if (auto cft = expr->dt->cast()) { + if (auto exp = cft->get_exponent_type()) { + // Non-empty exponent type. First create a place SNode for the + // exponent value. 
+ auto &exp_node = insert_children(SNodeType::place); + exp_node.dt = exp; + exp_node.name = expr->ident.raw_name() + "_exp"; + new_exp_snode = &exp_node; + } + } auto &child = insert_children(SNodeType::place); expr->set_snode(&child); child.name = expr->ident.raw_name(); @@ -51,6 +62,9 @@ void SNode::place(Expr &expr_, const std::vector &offset) { } expr->snode->expr.set(Expr(expr)); child.dt = expr->dt; + if (new_exp_snode) { + child.exp_snode = new_exp_snode; + } if (!offset.empty()) child.set_index_offsets(offset); } @@ -282,7 +296,11 @@ void SNode::print() { for (int i = 0; i < depth; i++) { fmt::print(" "); } - fmt::print("{}\n", get_node_type_name_hinted()); + fmt::print("{}", get_node_type_name_hinted()); + if (exp_snode) { + fmt::print(" exp={}\n", exp_snode->get_node_type_name()); + } + fmt::print("\n"); for (auto &c : ch) { c->print(); } diff --git a/taichi/ir/snode.h b/taichi/ir/snode.h index 6444c5f4d00d0..da559d3575883 100644 --- a/taichi/ir/snode.h +++ b/taichi/ir/snode.h @@ -85,6 +85,8 @@ class SNode { Kernel *reader_kernel{}; Kernel *writer_kernel{}; Expr expr; + SNode *exp_snode{}; // for CustomFloatType with exponent bits + int bit_offset{0}; // for children of bit_struct only // is_bit_level=false: the SNode is not bitpacked // is_bit_level=true: the SNode is bitpacked (i.e., strictly inside bit_struct diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 3500d72e977fd..3baf61a9c411c 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -116,12 +116,25 @@ CustomIntType::CustomIntType(int num_bits, } CustomFloatType::CustomFloatType(Type *digits_type, + Type *exponent_type, Type *compute_type, float64 scale) - : digits_type_(digits_type), compute_type_(compute_type), scale_(scale) { + : digits_type_(digits_type), + exponent_type_(exponent_type), + compute_type_(compute_type), + scale_(scale) { TI_ASSERT(digits_type->is()); TI_ASSERT(compute_type->is()); TI_ASSERT(is_real(compute_type->as())); + + if (exponent_type_) { + // We only 
support f32 as compute type when using exponents + TI_ASSERT(compute_type_->is_primitive(PrimitiveTypeID::f32)); + // Exponent must be unsigned custom int + TI_ASSERT(exponent_type->is()); + TI_ASSERT(exponent_type->as()->get_num_bits() <= 8); + TI_ASSERT(get_digit_bits() <= 23); + } } std::string CustomFloatType::to_string() const { @@ -129,6 +142,21 @@ std::string CustomFloatType::to_string() const { compute_type_->to_string(), scale_); } +int CustomFloatType::get_exponent_conversion_offset() const { + // Note that f32 has exponent offset -127 + return 127 - + (1 << (exponent_type_->as()->get_num_bits() - 1)) + 1; +} + +int CustomFloatType::get_digit_bits() const { + return digits_type_->as()->get_num_bits() - + (int)get_is_signed(); +} + +bool CustomFloatType::get_is_signed() const { + return digits_type_->as()->get_is_signed(); +} + BitStructType::BitStructType(PrimitiveType *physical_type, std::vector member_types, std::vector member_bit_offsets) diff --git a/taichi/ir/type.h b/taichi/ir/type.h index f6005eb8b558b..e978b2054fc3d 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -212,7 +212,10 @@ class CustomIntType : public Type { class CustomFloatType : public Type { public: - CustomFloatType(Type *digits_type, Type *compute_type, float64 scale); + CustomFloatType(Type *digits_type, + Type *exponent_type, + Type *compute_type, + float64 scale); std::string to_string() const override; @@ -224,12 +227,23 @@ class CustomFloatType : public Type { return digits_type_; } + Type *get_exponent_type() { + return exponent_type_; + } + + int get_exponent_conversion_offset() const; + + int get_digit_bits() const; + + bool get_is_signed() const; + Type *get_compute_type() override { return compute_type_; } private: Type *digits_type_{nullptr}; + Type *exponent_type_{nullptr}; Type *compute_type_{nullptr}; float64 scale_; }; diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index 8fc6fa7b797ea..cac8533110e83 100644 --- 
a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -50,12 +50,13 @@ Type *TypeFactory::get_custom_int_type(int num_bits, } Type *TypeFactory::get_custom_float_type(Type *digits_type, + Type *exponent_type, Type *compute_type, float64 scale) { - auto key = std::make_tuple(digits_type, compute_type, scale); + auto key = std::make_tuple(digits_type, exponent_type, compute_type, scale); if (custom_float_types.find(key) == custom_float_types.end()) { - custom_float_types[key] = - std::make_unique(digits_type, compute_type, scale); + custom_float_types[key] = std::make_unique( + digits_type, exponent_type, compute_type, scale); } return custom_float_types[key].get(); } diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index f1e6583ed831b..b724ab6fefbdb 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -26,6 +26,7 @@ class TypeFactory { int compute_type_bits = 32); Type *get_custom_float_type(Type *digits_type, + Type *exponent_type, Type *compute_type, float64 scale); @@ -56,7 +57,7 @@ class TypeFactory { std::map, std::unique_ptr> custom_int_types; // TODO: use unordered map - std::map, std::unique_ptr> + std::map, std::unique_ptr> custom_float_types; // TODO: avoid duplication diff --git a/taichi/lang_util.cpp b/taichi/lang_util.cpp index d2a6c7bd2cacb..27a3ad0aaaf94 100644 --- a/taichi/lang_util.cpp +++ b/taichi/lang_util.cpp @@ -68,11 +68,19 @@ real measure_cpe(std::function target, std::string data_type_format(DataType dt) { if (dt->is_primitive(PrimitiveTypeID::i32)) { return "%d"; + } else if (dt->is_primitive(PrimitiveTypeID::u32)) { + return "%u"; } else if (dt->is_primitive(PrimitiveTypeID::i64)) { #if defined(TI_PLATFORM_UNIX) return "%lld"; #else return "%I64d"; +#endif + } else if (dt->is_primitive(PrimitiveTypeID::u64)) { +#if defined(TI_PLATFORM_UNIX) + return "%llu"; +#else + return "%I64u"; #endif } else if (dt->is_primitive(PrimitiveTypeID::f32)) { return "%f"; @@ -87,8 +95,8 @@ std::string 
data_type_format(DataType dt) { int data_type_size(DataType t) { // TODO: - // 1. Ensure in the old code, pointer attributes of t are correct (by setting - // a loud failure on pointers); + // 1. Ensure in the old code, pointer attributes of t are correct (by + // setting a loud failure on pointers); // 2. Support pointer types here. t.set_is_pointer(false); if (false) { diff --git a/taichi/llvm/llvm_codegen_utils.h b/taichi/llvm/llvm_codegen_utils.h index bd9c1e717d334..404c7dcb36c6a 100644 --- a/taichi/llvm/llvm_codegen_utils.h +++ b/taichi/llvm/llvm_codegen_utils.h @@ -126,6 +126,7 @@ class LLVMModuleBuilder { return call(this->builder.get(), func_name, std::forward(args)...); } + // TODO(type): return with std::tuple void read_bit_pointer(llvm::Value *ptr, llvm::Value *&byte_ptr, llvm::Value *&bit_offset) { diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 505bd489389d6..a1c43ea570b36 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -741,7 +741,8 @@ void export_lang(py::module &m) { py::arg("compute_type_bits") = 32, py::return_value_policy::reference) .def("get_custom_float_type", &TypeFactory::get_custom_float_type, - py::arg("digits_type"), py::arg("compute_type"), py::arg("scale"), + py::arg("digits_type"), py::arg("exponent_type"), + py::arg("compute_type"), py::arg("scale"), py::return_value_policy::reference); m.def("get_type_factory_instance", TypeFactory::get_instance, diff --git a/taichi/struct/struct_llvm.cpp b/taichi/struct/struct_llvm.cpp index 0e60037f00dec..a9a3804fdde8f 100644 --- a/taichi/struct/struct_llvm.cpp +++ b/taichi/struct/struct_llvm.cpp @@ -73,14 +73,15 @@ void StructCompilerLLVM::generate_types(SNode &snode) { } else if (auto cft = ch->dt->cast()) { component_cit = cft->get_digits_type()->as(); } else { - TI_NOT_IMPLEMENTED + TI_ERROR("Type {} not supported.", ch->dt->to_string()); } component_cit->set_physical_type(snode.physical_type); if (!arch_is_cpu(arch)) { - 
TI_ERROR_IF(data_type_bits(snode.physical_type) <= 16, + TI_ERROR_IF(data_type_bits(snode.physical_type) < 32, "bit_struct physical type must be at least 32 bits on " "non-CPU backends."); } + ch->bit_offset = total_offset; total_offset += component_cit->get_num_bits(); } diff --git a/tests/python/test_custom_float.py b/tests/python/test_custom_float.py index b661fc709af74..c68f4676081fc 100644 --- a/tests/python/test_custom_float.py +++ b/tests/python/test_custom_float.py @@ -5,8 +5,8 @@ @ti.test(require=ti.extension.quant) def test_custom_float(): - ci13 = ti.type_factory_.get_custom_int_type(13, True) - cft = ti.type_factory_.get_custom_float_type(ci13, ti.f32.get_ptr(), 0.1) + ci13 = ti.type_factory.custom_int(bits=13) + cft = ti.type_factory.custom_float(significand_type=ci13, scale=0.1) x = ti.field(dtype=cft) ti.root._bit_struct(num_bits=32).place(x) @@ -27,9 +27,9 @@ def foo(): @ti.test(require=ti.extension.quant) def test_custom_matrix_rotation(): - ci16 = ti.type_factory_.get_custom_int_type(16, True) - cft = ti.type_factory_.get_custom_float_type(ci16, ti.f32.get_ptr(), - 1.2 / (2**15)) + ci16 = ti.type_factory.custom_int(bits=16) + cft = ti.type_factory.custom_float(significand_type=ci16, + scale=1.2 / (2**15)) x = ti.Matrix.field(2, 2, dtype=cft) @@ -55,8 +55,8 @@ def rotate_18_degrees(): @ti.test(require=ti.extension.quant) def test_custom_float_implicit_cast(): - ci13 = ti.type_factory_.get_custom_int_type(13, True) - cft = ti.type_factory_.get_custom_float_type(ci13, ti.f32.get_ptr(), 0.1) + ci13 = ti.type_factory.custom_int(bits=13) + cft = ti.type_factory.custom_float(significand_type=ci13, scale=0.1) x = ti.field(dtype=cft) ti.root._bit_struct(num_bits=32).place(x) @@ -71,8 +71,8 @@ def foo(): @ti.test(require=ti.extension.quant) def test_cache_read_only(): - ci15 = ti.type_factory_.get_custom_int_type(15, True) - cft = ti.type_factory_.get_custom_float_type(ci15, ti.f32.get_ptr(), 0.1) + ci15 = ti.type_factory.custom_int(bits=15) + cft = 
ti.type_factory.custom_float(significand_type=ci15, scale=0.1) x = ti.field(dtype=cft) ti.root._bit_struct(num_bits=32).place(x) diff --git a/tests/python/test_custom_float_exponents.py b/tests/python/test_custom_float_exponents.py new file mode 100644 index 0000000000000..53ec332d80f01 --- /dev/null +++ b/tests/python/test_custom_float_exponents.py @@ -0,0 +1,103 @@ +import taichi as ti +import numpy as np +import pytest + + +@ti.test(require=ti.extension.quant) +def test_custom_float_unsigned(): + cu13 = ti.type_factory.custom_int(13, False) + exp = ti.type_factory.custom_int(6, False) + cft = ti.type_factory.custom_float(significand_type=cu13, + exponent_type=exp, + scale=1) + x = ti.field(dtype=cft) + + ti.root._bit_struct(num_bits=32).place(x) + + tests = [ + 1 / 1024, 1.75 / 1024, 0.25, 0.5, 0.75, 1, 2, 3, 4, 5, 6, 7, 128, 256, + 512, 1024 + ] + + for v in tests: + x[None] = v + assert x[None] == v + + +@ti.test(require=ti.extension.quant) +def test_custom_float_signed(): + cu13 = ti.type_factory.custom_int(13, True) + exp = ti.type_factory.custom_int(6, False) + cft = ti.type_factory.custom_float(significand_type=cu13, + exponent_type=exp, + scale=1) + x = ti.field(dtype=cft) + + ti.root._bit_struct(num_bits=32).place(x) + + tests = [-0.125, -0.5, -2, -4, -6, -7, -8, -9] + + for v in tests: + x[None] = v + assert x[None] == v + + x[None] = -v + assert x[None] == -v + + +@pytest.mark.parametrize('digits_bits', [23, 24]) +@ti.test(require=ti.extension.quant) +def test_custom_float_precision(digits_bits): + cu24 = ti.type_factory.custom_int(digits_bits, True) + exp = ti.type_factory.custom_int(8, False) + cft = ti.type_factory.custom_float(significand_type=cu24, + exponent_type=exp, + scale=1) + x = ti.field(dtype=cft) + + ti.root._bit_struct(num_bits=32).place(x) + + tests = [np.float32(np.pi), np.float32(np.pi * (1 << 100))] + + for v in tests: + x[None] = v + if digits_bits == 24: + # Sufficient digits + assert x[None] == v + else: + # The binary 
representation of np.float32(np.pi) ends with 1, so removing one digit will result in a different number. + assert x[None] != v + assert x[None] == pytest.approx(v, rel=3e-7) + + +@pytest.mark.parametrize('signed', [True, False]) +@ti.test(require=ti.extension.quant) +def test_custom_float_truncation(signed): + cit = ti.type_factory.custom_int(2, signed) + exp = ti.type_factory.custom_int(5, False) + cft = ti.type_factory.custom_float(significand_type=cit, + exponent_type=exp, + scale=1) + x = ti.field(dtype=cft) + + ti.root._bit_struct(num_bits=32).place(x) + + # Sufficient digits + for v in [1, 1.5]: + x[None] = v + assert x[None] == v + + x[None] = 1.75 + if signed: + # Insufficient digits + assert x[None] == 1.5 + else: + # Sufficient digits + assert x[None] == 1.75 + + # Insufficient digits + x[None] = 1.875 + if signed: + assert x[None] == 1.5 + else: + assert x[None] == 1.75 diff --git a/tests/python/test_custom_type_atomics.py b/tests/python/test_custom_type_atomics.py index 375de25397b70..aa150f9a79b17 100644 --- a/tests/python/test_custom_type_atomics.py +++ b/tests/python/test_custom_type_atomics.py @@ -64,8 +64,8 @@ def foo(): def test_custom_float_atomics(): ci13 = ti.type_factory_.get_custom_int_type(13, True) ci19 = ti.type_factory_.get_custom_int_type(19, False) - cft13 = ti.type_factory_.get_custom_float_type(ci13, ti.f32.get_ptr(), 0.1) - cft19 = ti.type_factory_.get_custom_float_type(ci19, ti.f32.get_ptr(), 0.1) + cft13 = ti.type_factory.custom_float(significand_type=ci13, scale=0.1) + cft19 = ti.type_factory.custom_float(significand_type=ci19, scale=0.1) x = ti.field(dtype=cft13) y = ti.field(dtype=cft19)