diff --git a/taichi/backends/cuda/codegen_cuda.cpp b/taichi/backends/cuda/codegen_cuda.cpp index e4785717d28bf..6b8bd6f191e0a 100644 --- a/taichi/backends/cuda/codegen_cuda.cpp +++ b/taichi/backends/cuda/codegen_cuda.cpp @@ -321,6 +321,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { fmt::format("cuda_rand_{}", data_type_short_name(stmt->ret_type)), {get_context()}); } + void visit(RangeForStmt *for_stmt) override { create_naive_range_for(for_stmt); } @@ -385,6 +386,21 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { return true; // on CUDA, pass the argument by value } + llvm::Value *create_intrinsic_load(const DataType &dtype, + llvm::Value *data_ptr) { + auto llvm_dtype = llvm_type(dtype); + auto llvm_dtype_ptr = llvm::PointerType::get(llvm_type(dtype), 0); + llvm::Intrinsic::ID intrin; + if (is_real(dtype)) { + intrin = llvm::Intrinsic::nvvm_ldg_global_f; + } else { + intrin = llvm::Intrinsic::nvvm_ldg_global_i; + } + return builder->CreateIntrinsic( + intrin, {llvm_dtype, llvm_dtype_ptr}, + {data_ptr, tlctx->get_constant(data_type_size(dtype))}); + } + void visit(GlobalLoadStmt *stmt) override { if (auto get_ch = stmt->ptr->cast<GetChStmt>(); get_ch) { bool should_cache_as_read_only = false; @@ -398,18 +414,30 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { // Issue an CUDA "__ldg" instruction so that data are cached in // the CUDA read-only data cache. 
auto dtype = stmt->ret_type; - auto llvm_dtype = llvm_type(dtype); - auto llvm_dtype_ptr = llvm::PointerType::get(llvm_type(dtype), 0); - llvm::Intrinsic::ID intrin; - if (is_real(dtype)) { - intrin = llvm::Intrinsic::nvvm_ldg_global_f; + if (auto ptr_type = stmt->ptr->ret_type->as<PointerType>(); + ptr_type->is_bit_pointer()) { + auto val_type = ptr_type->get_pointee_type(); + llvm::Value *data_ptr = nullptr; + llvm::Value *bit_offset = nullptr; + if (auto cit = val_type->cast<CustomIntType>()) { + dtype = cit->get_physical_type(); + } else if (auto cft = val_type->cast<CustomFloatType>()) { + dtype = cft->get_compute_type() + ->as<CustomIntType>() + ->get_physical_type(); + } else { + TI_NOT_IMPLEMENTED; + } + read_bit_pointer(llvm_val[stmt->ptr], data_ptr, bit_offset); + data_ptr = builder->CreateBitCast(data_ptr, llvm_ptr_type(dtype)); + auto data = create_intrinsic_load(dtype, data_ptr); + llvm_val[stmt] = extract_custom_int(data, bit_offset, val_type); + if (val_type->is<CustomFloatType>()) { + llvm_val[stmt] = reconstruct_custom_float(llvm_val[stmt], val_type); + } } else { - intrin = llvm::Intrinsic::nvvm_ldg_global_i; + llvm_val[stmt] = create_intrinsic_load(dtype, llvm_val[stmt->ptr]); } - - llvm_val[stmt] = builder->CreateIntrinsic( - intrin, {llvm_dtype, llvm_dtype_ptr}, - {llvm_val[stmt->ptr], tlctx->get_constant(data_type_size(dtype))}); } else { CodeGenLLVM::visit(stmt); } diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index a3882125765e3..e8fdc0ed62eac 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1141,6 +1141,8 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { } llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr; read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset); + // TODO(type): CUDA only supports atomicCAS on 32- and 64-bit integers + // try to support CustomInt/Float Type with 16-bits or 8-bits physical type auto func_name = fmt::format("set_partial_bits_b{}", data_type_bits(cit->get_physical_type())); builder->CreateCall( @@ 
-1157,15 +1159,23 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { llvm::Value *CodeGenLLVM::load_as_custom_int(Stmt *ptr, Type *load_type) { auto *cit = load_type->as<CustomIntType>(); - // 1. load bit pointer + // load bit pointer llvm::Value *byte_ptr, *bit_offset; read_bit_pointer(llvm_val[ptr], byte_ptr, bit_offset); auto bit_level_container = builder->CreateLoad(builder->CreateBitCast( byte_ptr, llvm_ptr_type(cit->get_physical_type()))); - // 2. bit shifting + + return extract_custom_int(bit_level_container, bit_offset, load_type); +} + +llvm::Value *CodeGenLLVM::extract_custom_int(llvm::Value *physical_value, + llvm::Value *bit_offset, + Type *load_type) { + // bit shifting // first left shift `physical_type - (offset + num_bits)` // then right shift `physical_type - num_bits` + auto cit = load_type->as<CustomIntType>(); auto bit_end = builder->CreateAdd(bit_offset, tlctx->get_constant(cit->get_num_bits())); auto left = builder->CreateSub( @@ -1173,9 +1183,9 @@ llvm::Value *CodeGenLLVM::load_as_custom_int(Stmt *ptr, Type *load_type) { auto right = builder->CreateSub( tlctx->get_constant(data_type_bits(cit->get_physical_type())), tlctx->get_constant(cit->get_num_bits())); - left = builder->CreateIntCast(left, bit_level_container->getType(), false); - right = builder->CreateIntCast(right, bit_level_container->getType(), false); - auto step1 = builder->CreateShl(bit_level_container, left); + left = builder->CreateIntCast(left, physical_value->getType(), false); + right = builder->CreateIntCast(right, physical_value->getType(), false); + auto step1 = builder->CreateShl(physical_value, left); llvm::Value *step2 = nullptr; if (cit->get_is_signed()) @@ -1187,29 +1197,34 @@ llvm::Value *CodeGenLLVM::load_as_custom_int(Stmt *ptr, Type *load_type) { cit->get_is_signed()); } +llvm::Value *CodeGenLLVM::reconstruct_custom_float(llvm::Value *digits, + Type *load_type) { + // Compute float(digits) * scale + auto cft = load_type->as<CustomFloatType>(); + llvm::Value *cast = nullptr; + auto compute_type = 
cft->get_compute_type()->as<PrimitiveType>(); + if (cft->get_digits_type()->cast<CustomIntType>()->get_is_signed()) { + cast = builder->CreateSIToFP(digits, llvm_type(compute_type)); + } else { + cast = builder->CreateUIToFP(digits, llvm_type(compute_type)); + } + llvm::Value *s = + llvm::ConstantFP::get(*llvm_context, llvm::APFloat(cft->get_scale())); + s = builder->CreateFPCast(s, llvm_type(compute_type)); + return builder->CreateFMul(cast, s); +} + void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { int width = stmt->width(); TI_ASSERT(width == 1); - if (auto ptr_type = stmt->ptr->ret_type->cast<PointerType>(); - ptr_type->is_bit_pointer()) { + auto ptr_type = stmt->ptr->ret_type->as<PointerType>(); + if (ptr_type->is_bit_pointer()) { auto val_type = ptr_type->get_pointee_type(); if (val_type->is<CustomIntType>()) { llvm_val[stmt] = load_as_custom_int(stmt->ptr, val_type); } else if (auto cft = val_type->cast<CustomFloatType>()) { auto digits = load_as_custom_int(stmt->ptr, cft->get_digits_type()); - // Compute float(digits) * scale - llvm::Value *cast = nullptr; - auto compute_type = cft->get_compute_type()->as<PrimitiveType>(); - if (cft->get_digits_type()->cast<CustomIntType>()->get_is_signed()) { - cast = builder->CreateSIToFP(digits, llvm_type(compute_type)); - } else { - cast = builder->CreateUIToFP(digits, llvm_type(compute_type)); - } - llvm::Value *s = - llvm::ConstantFP::get(*llvm_context, llvm::APFloat(cft->get_scale())); - s = builder->CreateFPCast(s, llvm_type(compute_type)); - auto scaled = builder->CreateFMul(cast, s); - llvm_val[stmt] = scaled; + llvm_val[stmt] = reconstruct_custom_float(digits, val_type); } else { TI_NOT_IMPLEMENTED } diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index 63fcbf673ec52..5a8bfc4161d05 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -198,6 +198,12 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *load_as_custom_int(Stmt *ptr, Type *load_type); + llvm::Value *extract_custom_int(llvm::Value *physical_value, + llvm::Value *bit_offset, + Type 
*load_type); + + llvm::Value *reconstruct_custom_float(llvm::Value *digits, Type *load_type); + void visit(GlobalLoadStmt *stmt) override; void visit(ElementShuffleStmt *stmt) override; diff --git a/tests/python/test_bit_array.py b/tests/python/test_bit_array.py index 4429eca2c9a7d..9d236072eed49 100644 --- a/tests/python/test_bit_array.py +++ b/tests/python/test_bit_array.py @@ -2,7 +2,7 @@ import numpy as np -@ti.test(arch=ti.cpu, debug=True, cfg_optimization=False) +@ti.test(require=ti.extension.quant, debug=True, cfg_optimization=False) def test_1D_bit_array(): ci1 = ti.type_factory_.get_custom_int_type(1, False) @@ -28,7 +28,7 @@ def verify_val(): verify_val() -@ti.test(arch=ti.cpu, debug=True, cfg_optimization=False) +@ti.test(require=ti.extension.quant, debug=True, cfg_optimization=False) def test_2D_bit_array(): ci1 = ti.type_factory_.get_custom_int_type(1, False) diff --git a/tests/python/test_bit_struct.py b/tests/python/test_bit_struct.py index 8f0d1aeddb919..1ea0e79e9c2a1 100644 --- a/tests/python/test_bit_struct.py +++ b/tests/python/test_bit_struct.py @@ -2,7 +2,7 @@ import numpy as np -@ti.test(arch=ti.cpu, debug=True, cfg_optimization=False) +@ti.test(require=ti.extension.quant, debug=True, cfg_optimization=False) def test_simple_array(): ci13 = ti.type_factory_.get_custom_int_type(13, True) cu19 = ti.type_factory_.get_custom_int_type(19, False) @@ -36,7 +36,7 @@ def verify_val(): verify_val.__wrapped__() -@ti.test(arch=ti.cpu, debug=True, cfg_optimization=False) +@ti.test(require=ti.extension.quant, debug=True, cfg_optimization=False) def test_custom_int_load_and_store(): ci13 = ti.type_factory_.get_custom_int_type(13, True) cu14 = ti.type_factory_.get_custom_int_type(14, False) diff --git a/tests/python/test_custom_float.py b/tests/python/test_custom_float.py index 776612b3070c9..f644245bc3b78 100644 --- a/tests/python/test_custom_float.py +++ b/tests/python/test_custom_float.py @@ -2,7 +2,7 @@ from pytest import approx -@ti.test(arch=ti.cpu, 
cfg_optimization=False) +@ti.test(require=ti.extension.quant, cfg_optimization=False) def test_custom_float(): ci13 = ti.type_factory_.get_custom_int_type(13, True) cft = ti.type_factory_.get_custom_float_type(ci13, ti.f32.get_ptr(), 0.1)