Skip to content

Commit

Permalink
[Lang] [type] Support placing QuantFixedType under quant_array
Browse files Browse the repository at this point in the history
  • Loading branch information
strongoier committed Jul 11, 2022
1 parent 3459901 commit 9b4a451
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 34 deletions.
4 changes: 2 additions & 2 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -560,9 +560,9 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
if (auto get_ch = stmt->src->cast<GetChStmt>()) {
bool should_cache_as_read_only = current_offload->mem_access_opt.has_flag(
get_ch->output_snode, SNodeAccessFlag::read_only);
global_load(stmt, should_cache_as_read_only);
create_global_load(stmt, should_cache_as_read_only);
} else {
global_load(stmt, false);
create_global_load(stmt, false);
}
}

Expand Down
29 changes: 14 additions & 15 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1394,19 +1394,18 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) {
auto ptr_type = stmt->dest->ret_type->as<PointerType>();
if (ptr_type->is_bit_pointer()) {
auto pointee_type = ptr_type->get_pointee_type();
if (!pointee_type->is<QuantIntType>()) {
if (stmt->dest->as<GetChStmt>()->input_snode->type ==
SNodeType::bit_struct) {
TI_ERROR(
"Bit struct stores with type {} should have been "
"handled by BitStructStoreStmt.",
pointee_type->to_string());
} else {
TI_ERROR("Quant array only supports quant int type.");
}
if (stmt->dest->as<GetChStmt>()->input_snode->type == SNodeType::bit_struct) {
TI_ERROR(
"Bit struct stores with type {} should have been handled by BitStructStoreStmt.",
pointee_type->to_string());
}
if (auto qit = pointee_type->cast<QuantIntType>()) {
store_quant_int(llvm_val[stmt->dest], qit, llvm_val[stmt->val], true);
} else if (auto qfxt = pointee_type->cast<QuantFixedType>()) {
store_quant_fixed(llvm_val[stmt->dest], qfxt, llvm_val[stmt->val], true);
} else {
TI_NOT_IMPLEMENTED;
}
store_quant_int(llvm_val[stmt->dest], pointee_type->as<QuantIntType>(),
llvm_val[stmt->val], true);
} else {
builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]);
}
Expand All @@ -1417,8 +1416,8 @@ llvm::Value *CodeGenLLVM::create_intrinsic_load(const DataType &dtype,
TI_NOT_IMPLEMENTED;
}

void CodeGenLLVM::global_load(GlobalLoadStmt *stmt,
bool should_cache_as_read_only) {
void CodeGenLLVM::create_global_load(GlobalLoadStmt *stmt,
bool should_cache_as_read_only) {
auto ptr = llvm_val[stmt->src];
auto ptr_type = stmt->src->ret_type->as<PointerType>();
if (ptr_type->is_bit_pointer()) {
Expand Down Expand Up @@ -1449,7 +1448,7 @@ void CodeGenLLVM::global_load(GlobalLoadStmt *stmt,
}

void CodeGenLLVM::visit(GlobalLoadStmt *stmt) {
global_load(stmt, false);
create_global_load(stmt, false);
}

void CodeGenLLVM::visit(ElementShuffleStmt *stmt){
Expand Down
11 changes: 7 additions & 4 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

llvm::Value *atomic_add_quant_int(AtomicOpStmt *stmt, QuantIntType *qit);

llvm::Value *quant_fixed_to_quant_int(QuantFixedType *qfxt,
QuantIntType *qit,
llvm::Value *real);
llvm::Value *to_quant_fixed(llvm::Value *real, QuantFixedType *qfxt);

virtual llvm::Value *optimized_reduction(AtomicOpStmt *stmt);

Expand All @@ -257,6 +255,11 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *value,
bool atomic);

void store_quant_fixed(llvm::Value *bit_ptr,
QuantFixedType *qfxt,
llvm::Value *value,
bool atomic);

void store_masked(llvm::Value *byte_ptr,
uint64 mask,
llvm::Value *value,
Expand Down Expand Up @@ -313,7 +316,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
QuantFloatType *qflt,
bool shared_exponent);

void global_load(GlobalLoadStmt *stmt, bool should_cache_as_read_only);
void create_global_load(GlobalLoadStmt *stmt, bool should_cache_as_read_only);

void visit(GlobalLoadStmt *stmt) override;

Expand Down
22 changes: 12 additions & 10 deletions taichi/codegen/llvm/codegen_llvm_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,18 @@ llvm::Value *CodeGenLLVM::atomic_add_quant_fixed(AtomicOpStmt *stmt,
auto [byte_ptr, bit_offset] = load_bit_ptr(llvm_val[stmt->dest]);
auto physical_type = byte_ptr->getType()->getPointerElementType();
auto qit = qfxt->get_digits_type()->as<QuantIntType>();
auto val_store = quant_fixed_to_quant_int(qfxt, qit, llvm_val[stmt->val]);
auto val_store = to_quant_fixed(llvm_val[stmt->val], qfxt);
val_store = builder->CreateSExt(val_store, physical_type);
return create_call(fmt::format("atomic_add_partial_bits_b{}",
physical_type->getIntegerBitWidth()),
{byte_ptr, bit_offset,
tlctx->get_constant(qit->get_num_bits()), val_store});
}

llvm::Value *CodeGenLLVM::quant_fixed_to_quant_int(QuantFixedType *qfxt,
QuantIntType *qit,
llvm::Value *real) {
llvm::Value *s = nullptr;

llvm::Value *CodeGenLLVM::to_quant_fixed(llvm::Value *real, QuantFixedType *qfxt) {
// Compute int(real * (1.0 / scale) + 0.5)
auto s_numeric = 1.0 / qfxt->get_scale();
auto compute_type = qfxt->get_compute_type();
s = builder->CreateFPCast(tlctx->get_constant(s_numeric),
llvm_type(compute_type));
auto s = builder->CreateFPCast(tlctx->get_constant(1.0 / qfxt->get_scale()), llvm_type(compute_type));
auto input_real = builder->CreateFPCast(real, llvm_type(compute_type));
auto scaled = builder->CreateFMul(input_real, s);

Expand All @@ -60,6 +54,7 @@ llvm::Value *CodeGenLLVM::quant_fixed_to_quant_int(QuantFixedType *qfxt,
fmt::format("rounding_prepare_f{}", data_type_bits(compute_type)),
{scaled});

auto qit = qfxt->get_digits_type()->as<QuantIntType>();
if (qit->get_is_signed()) {
return builder->CreateFPToSI(scaled, llvm_type(qit->get_compute_type()));
} else {
Expand All @@ -81,6 +76,13 @@ void CodeGenLLVM::store_quant_int(llvm::Value *bit_ptr,
builder->CreateIntCast(value, physical_type, false)});
}

void CodeGenLLVM::store_quant_fixed(llvm::Value *bit_ptr,
QuantFixedType *qfxt,
llvm::Value *value,
bool atomic) {
store_quant_int(bit_ptr, qfxt->get_digits_type()->as<QuantIntType>(), to_quant_fixed(value, qfxt), atomic);
}

void CodeGenLLVM::store_masked(llvm::Value *byte_ptr,
uint64 mask,
llvm::Value *value,
Expand Down Expand Up @@ -120,7 +122,7 @@ llvm::Value *CodeGenLLVM::quant_int_or_quant_fixed_to_bits(llvm::Value *val,
QuantIntType *qit = nullptr;
if (auto qfxt = input_type->cast<QuantFixedType>()) {
qit = qfxt->get_digits_type()->as<QuantIntType>();
val = quant_fixed_to_quant_int(qfxt, qit, val);
val = to_quant_fixed(val, qfxt);
} else {
qit = input_type->as<QuantIntType>();
}
Expand Down
10 changes: 7 additions & 3 deletions taichi/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,13 @@ class QuantArrayType : public Type {
: physical_type_(physical_type),
element_type_(element_type_),
num_elements_(num_elements_) {
// TODO: avoid assertion?
TI_ASSERT(element_type_->is<QuantIntType>());
element_num_bits_ = element_type_->as<QuantIntType>()->get_num_bits();
if (auto qit = element_type_->cast<QuantIntType>()) {
element_num_bits_ = qit->get_num_bits();
} else if (auto qfxt = element_type_->cast<QuantFixedType>()) {
element_num_bits_ = qfxt->get_digits_type()->as<QuantIntType>()->get_num_bits();
} else {
TI_ERROR("Quant array only supports quant int/fixed type for now.");
}
}

std::string to_string() const override;
Expand Down
24 changes: 24 additions & 0 deletions tests/python/test_quant_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,30 @@ def assign():
assign()


@test_utils.test(require=ti.extension.quant, debug=True)
def test_1D_quant_array_fixed():
qfxt = ti.types.quant.fixed(frac=8, range=2)

x = ti.field(dtype=qfxt)

N = 4

ti.root.quant_array(ti.i, N, num_bits=32).place(x)

@ti.kernel
def set_val():
for i in range(N):
x[i] = i * 0.5

@ti.kernel
def verify_val():
for i in range(N):
assert x[i] == i * 0.5

set_val()
verify_val()


@test_utils.test(require=ti.extension.quant, debug=True)
def test_2D_quant_array():
qu1 = ti.types.quant.int(1, False)
Expand Down

0 comments on commit 9b4a451

Please sign in to comment.