Skip to content

Commit

Permalink
[type] Support exponents in CustomFloatType (#2122)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanming-hu authored Dec 27, 2020
1 parent 0f81998 commit f1768ca
Show file tree
Hide file tree
Showing 18 changed files with 415 additions and 59 deletions.
6 changes: 5 additions & 1 deletion python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .matrix import Matrix, Vector
from .transformer import TaichiSyntaxError
from .ndrange import ndrange, GroupedNDRange
from .type_factory import TypeFactory
from copy import deepcopy as _deepcopy
import functools
import os
Expand Down Expand Up @@ -46,9 +47,12 @@
kernel_profiler_total_time = lambda: get_runtime(
).prog.kernel_profiler_total_time()

# Unstable API
# Legacy API
type_factory_ = core.get_type_factory_instance()

# Unstable API
type_factory = TypeFactory()


def memory_profiler_print():
get_runtime().materialize()
Expand Down
5 changes: 4 additions & 1 deletion python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,17 @@ def _bit_array(self, indices, dimensions, num_bits):
dimensions = [dimensions] * len(indices)
return SNode(self.ptr.bit_array(indices, dimensions, num_bits))

def place(self, *args, offset=None):
def place(self, *args, offset=None, shared_exponent=False):
from .expr import Expr
from .util import is_taichi_class
if offset is None:
offset = []
if isinstance(offset, numbers.Number):
offset = (offset, )
for arg in args:
assert shared_exponent == False
# TODO: implement shared exponent

if isinstance(arg, Expr):
self.ptr.place(Expr(arg).ptr, offset)
elif isinstance(arg, list):
Expand Down
20 changes: 20 additions & 0 deletions python/taichi/lang/type_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
class TypeFactory:
    """Python-side wrapper around the core type factory instance.

    Forwards requests for custom integer and custom float types to the
    object returned by ``ti_core.get_type_factory_instance()``.
    """
    def __init__(self):
        # Imported lazily, inside the constructor, as in the original code.
        from taichi.core import ti_core
        self.core = ti_core.get_type_factory_instance()

    def custom_int(self, bits, signed=True):
        """Return the core custom integer type for `bits` and `signed`.

        Both arguments are forwarded unchanged to
        ``core.get_custom_int_type``.
        """
        return self.core.get_custom_int_type(bits, signed)

    def custom_float(self,
                     significand_type,
                     exponent_type=None,
                     compute_type=None,
                     scale=1.0):
        """Return the core custom float type built from the given parts.

        Args:
            significand_type: forwarded as the first argument to
                ``core.get_custom_float_type``.
            exponent_type: forwarded unchanged; ``None`` by default.
            compute_type: forwarded unchanged when given. When ``None``, the
                runtime's ``default_fp.get_ptr()`` is used instead.
            scale: forwarded as the ``scale`` keyword argument.
        """
        if compute_type is None:
            # Only the default lookup needs the top-level package, so the
            # import is deferred to this branch.
            import taichi as ti
            compute_type = ti.get_runtime().default_fp.get_ptr()
        return self.core.get_custom_float_type(significand_type,
                                               exponent_type,
                                               compute_type,
                                               scale=scale)
203 changes: 172 additions & 31 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1172,6 +1172,24 @@ void CodeGenLLVM::visit(GlobalPtrStmt *stmt) {
TI_ERROR("Global Ptrs should have been lowered.");
}

// Emits the store of `value` into the custom-int slot addressed by `bit_ptr`.
//
// `bit_ptr` is unpacked by read_bit_pointer() into a byte pointer and a bit
// offset. The byte pointer is bit-cast to a pointer to cit's physical type,
// `value` is integer-cast (isSigned = false) to that physical type, and both
// are passed, together with the bit offset and cit's bit count, to the runtime
// function "set_partial_bits_b<N>", where N is data_type_bits() of the
// physical type.
void CodeGenLLVM::store_custom_int(llvm::Value *bit_ptr,
                                   CustomIntType *cit,
                                   llvm::Value *value) {
  auto physical_type = cit->get_physical_type();

  llvm::Value *byte_ptr = nullptr;
  llvm::Value *bit_offset = nullptr;
  read_bit_pointer(bit_ptr, byte_ptr, bit_offset);

  // TODO(type): CUDA only supports atomicCAS on 32- and 64-bit integers.
  // Try to support CustomInt/FloatType with 8/16-bit physical
  // types.
  auto setter_name =
      fmt::format("set_partial_bits_b{}", data_type_bits(physical_type));
  auto *container_ptr =
      builder->CreateBitCast(byte_ptr, llvm_ptr_type(physical_type));
  auto *num_bits = tlctx->get_constant(cit->get_num_bits());
  auto *physical_value = builder->CreateIntCast(
      value, llvm_type(physical_type), /*isSigned=*/false);

  create_call(setter_name,
              {container_ptr, bit_offset, num_bits, physical_value});
}

void CodeGenLLVM::visit(GlobalStoreStmt *stmt) {
TI_ASSERT(!stmt->parent->mask() || stmt->width() == 1);
TI_ASSERT(llvm_val[stmt->data]);
Expand All @@ -1185,33 +1203,75 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) {
cit = cit_;
store_value = llvm_val[stmt->data];
} else if (auto cft = pointee_type->cast<CustomFloatType>()) {
cit = cft->get_digits_type()->as<CustomIntType>();
store_value = float_to_custom_int(cft, cit, llvm_val[stmt->data]);
llvm::Value *digit_bits = nullptr;
auto digits_cit = cft->get_digits_type()->as<CustomIntType>();
if (auto exp = cft->get_exponent_type()) {
// Extract exponent and digits from compute type (assumed to be f32 for
// now).
TI_ASSERT(cft->get_compute_type()->is_primitive(PrimitiveTypeID::f32));

auto f32_bits = builder->CreateBitCast(
llvm_val[stmt->data], llvm::Type::getInt32Ty(*llvm_context));
auto exponent_bits = builder->CreateAShr(f32_bits, 23);
exponent_bits = builder->CreateAnd(exponent_bits,
tlctx->get_constant((1 << 8) - 1));
// f32 = 1 sign bit + 8 exponent bits + 23 fraction bits
auto value_bits = builder->CreateAShr(
f32_bits, tlctx->get_constant(23 - cft->get_digit_bits()));

digit_bits = builder->CreateAnd(
value_bits,
tlctx->get_constant((1 << (cft->get_digit_bits())) - 1));

if (cft->get_is_signed()) {
// extract the sign bit
auto sign_bit =
builder->CreateAnd(f32_bits, tlctx->get_constant(0x80000000u));
// insert the sign bit to digit bits
digit_bits = builder->CreateOr(
digit_bits,
builder->CreateLShr(sign_bit, 31 - cft->get_digit_bits()));
}

auto exponent_cit = exp->as<CustomIntType>();

auto digits_snode = stmt->ptr->as<GetChStmt>()->output_snode;
auto exponent_snode = digits_snode->exp_snode;

// Since we have fewer bits in the exponent type than in f32, an
// offset is necessary to make sure the stored exponent values are
// representable by the exponent custom int type.
exponent_bits = builder->CreateSub(
exponent_bits,
tlctx->get_constant(cft->get_exponent_conversion_offset()));

// Compute the bit pointer of the exponent bits.
TI_ASSERT(digits_snode->parent == exponent_snode->parent);
auto exponent_bit_ptr =
offset_bit_ptr(llvm_val[stmt->ptr], exponent_snode->bit_offset -
digits_snode->bit_offset);
store_custom_int(exponent_bit_ptr, exponent_cit, exponent_bits);
store_value = digit_bits;
} else {
digit_bits = llvm_val[stmt->data];
store_value = float_to_custom_int(cft, digits_cit, digit_bits);
}
cit = digits_cit;
} else {
TI_NOT_IMPLEMENTED
}
llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr;
read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset);
// TODO(type): CUDA only supports atomicCAS on 32- and 64-bit integers.
// Try to support CustomInt/FloatType with 8/16-bit physical
// types.
create_call(fmt::format("set_partial_bits_b{}",
data_type_bits(cit->get_physical_type())),
{builder->CreateBitCast(
byte_ptr, llvm_ptr_type(cit->get_physical_type())),
bit_offset, tlctx->get_constant(cit->get_num_bits()),
builder->CreateIntCast(
store_value, llvm_type(cit->get_physical_type()), false)});
store_custom_int(llvm_val[stmt->ptr], cit, store_value);
} else {
builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]);
}
}

llvm::Value *CodeGenLLVM::load_as_custom_int(Stmt *ptr, Type *load_type) {
llvm::Value *CodeGenLLVM::load_as_custom_int(llvm::Value *ptr,
Type *load_type) {
auto *cit = load_type->as<CustomIntType>();
// load bit pointer
llvm::Value *byte_ptr, *bit_offset;
read_bit_pointer(llvm_val[ptr], byte_ptr, bit_offset);
read_bit_pointer(ptr, byte_ptr, bit_offset);

auto bit_level_container = builder->CreateLoad(builder->CreateBitCast(
byte_ptr, llvm_ptr_type(cit->get_physical_type())));
Expand Down Expand Up @@ -1264,17 +1324,82 @@ llvm::Value *CodeGenLLVM::reconstruct_custom_float(llvm::Value *digits,
return builder->CreateFMul(cast, s);
}

// Loads a CustomFloatType whose digits and exponent are stored as two separate
// custom ints, and reassembles them into an f32 value.
//
// digits_bit_ptr:   bit pointer to the stored digits.
// exponent_bit_ptr: bit pointer to the stored exponent.
// cft:              the custom float type being loaded; its scale must be 1
//                   and its compute type must be f32 (anything else hits
//                   TI_NOT_IMPLEMENTED).
// Returns the reassembled bits bit-cast to an LLVM float.
llvm::Value *CodeGenLLVM::load_custom_float_with_exponent(
    llvm::Value *digits_bit_ptr,
    llvm::Value *exponent_bit_ptr,
    CustomFloatType *cft) {
  // f32 = 1 sign bit + 8 exponent bits + 23 fraction bits
  constexpr int kF32FractionBits = 23;
  constexpr int kF32SignBitPos = 31;

  // TODO: we ignore "scale" for CustomFloatType with exponent for now. Fix
  // this.
  TI_ASSERT(cft->get_scale() == 1);

  auto stored_digits =
      load_as_custom_int(digits_bit_ptr, cft->get_digits_type());
  auto stored_exponent = load_as_custom_int(
      exponent_bit_ptr, cft->get_exponent_type()->as<CustomIntType>());

  // The stored exponent has the conversion offset removed; add it back before
  // placing the exponent into the f32 bit pattern.
  auto exponent = builder->CreateAdd(
      stored_exponent,
      tlctx->get_constant(cft->get_exponent_conversion_offset()));

  if (!cft->get_compute_type()->is_primitive(PrimitiveTypeID::f32)) {
    TI_NOT_IMPLEMENTED;
  }

  // Assemble the f32 bit pattern, assuming the loaded digits and exponent are
  // i32 values. The exponent goes directly above the fraction field.
  auto exponent_field =
      builder->CreateShl(exponent, tlctx->get_constant(kF32FractionBits));

  // Keep only the bits that belong to the digits custom int, then shift them
  // up by (23 - digit_bits) so the digits sit at the top of the fraction
  // field.
  const auto stored_digits_mask =
      (1u << cft->get_digits_type()->as<CustomIntType>()->get_num_bits()) - 1;
  auto aligned_digits = builder->CreateAnd(stored_digits, stored_digits_mask);
  aligned_digits = builder->CreateShl(
      aligned_digits,
      tlctx->get_constant(kF32FractionBits - cft->get_digit_bits()));

  auto fraction_field =
      builder->CreateAnd(aligned_digits, (1u << kF32FractionBits) - 1);
  auto result_bits = builder->CreateOr(exponent_field, fraction_field);

  if (cft->get_is_signed()) {
    // For signed types the sign is read from bit 23 of the aligned digits
    // and moved up to bit 31, the f32 sign position.
    auto sign = builder->CreateAnd(
        aligned_digits, tlctx->get_constant(1u << kF32FractionBits));
    sign = builder->CreateShl(
        sign, tlctx->get_constant(kF32SignBitPos - kF32FractionBits));
    result_bits = builder->CreateOr(result_bits, sign);
  }

  return builder->CreateBitCast(result_bits,
                                llvm::Type::getFloatTy(*llvm_context));
}

void CodeGenLLVM::visit(GlobalLoadStmt *stmt) {
int width = stmt->width();
TI_ASSERT(width == 1);
auto ptr_type = stmt->ptr->ret_type->as<PointerType>();
if (ptr_type->is_bit_pointer()) {
auto val_type = ptr_type->get_pointee_type();
if (val_type->is<CustomIntType>()) {
llvm_val[stmt] = load_as_custom_int(stmt->ptr, val_type);
llvm_val[stmt] = load_as_custom_int(llvm_val[stmt->ptr], val_type);
} else if (auto cft = val_type->cast<CustomFloatType>()) {
auto digits = load_as_custom_int(stmt->ptr, cft->get_digits_type());
llvm_val[stmt] = reconstruct_custom_float(digits, val_type);
if (cft->get_exponent_type()) {
auto ptr = stmt->ptr->as<GetChStmt>();
TI_ASSERT(ptr->width() == 1);
auto digits_bit_ptr = llvm_val[ptr];
auto digits_snode = ptr->output_snode;
auto exponent_snode = digits_snode->exp_snode;
// Compute the bit pointer of the exponent bits.
TI_ASSERT(digits_snode->parent == exponent_snode->parent);
auto exponent_bit_ptr =
offset_bit_ptr(digits_bit_ptr, exponent_snode->bit_offset -
digits_snode->bit_offset);
llvm_val[stmt] = load_custom_float_with_exponent(digits_bit_ptr,
exponent_bit_ptr, cft);
} else {
auto digits =
load_as_custom_int(llvm_val[stmt->ptr], cft->get_digits_type());
llvm_val[stmt] = reconstruct_custom_float(digits, val_type);
}
} else {
TI_NOT_IMPLEMENTED
}
Expand Down Expand Up @@ -1374,7 +1499,7 @@ void CodeGenLLVM::visit(IntegerOffsetStmt *stmt){TI_NOT_IMPLEMENTED}

llvm::Value *CodeGenLLVM::create_bit_ptr_struct(llvm::Value *byte_ptr_base,
llvm::Value *bit_offset) {
// 1. create a bit pointer struct
// 1. get the bit pointer LLVM struct
// struct bit_pointer {
// i8* byte_ptr;
// i32 offset;
Expand All @@ -1383,21 +1508,37 @@ llvm::Value *CodeGenLLVM::create_bit_ptr_struct(llvm::Value *byte_ptr_base,
*llvm_context, {llvm::Type::getInt8PtrTy(*llvm_context),
llvm::Type::getInt32Ty(*llvm_context),
llvm::Type::getInt32Ty(*llvm_context)});
// 2. alloca the bit pointer struct
// 2. allocate the bit pointer struct
auto bit_ptr_struct = create_entry_block_alloca(struct_type);
// 3. store `input_ptr` into `bit_ptr_struct`
auto byte_ptr = builder->CreateBitCast(
byte_ptr_base, llvm::PointerType::getInt8PtrTy(*llvm_context));
builder->CreateStore(
byte_ptr, builder->CreateGEP(bit_ptr_struct, {tlctx->get_constant(0),
tlctx->get_constant(0)}));
// 4. store `offset` in `bit_ptr_struct`
builder->CreateStore(
bit_offset, builder->CreateGEP(bit_ptr_struct, {tlctx->get_constant(0),
tlctx->get_constant(1)}));
// 3. store `byte_ptr_base` into `bit_ptr_struct` (if provided)
if (byte_ptr_base) {
auto byte_ptr = builder->CreateBitCast(
byte_ptr_base, llvm::PointerType::getInt8PtrTy(*llvm_context));
builder->CreateStore(
byte_ptr, builder->CreateGEP(bit_ptr_struct, {tlctx->get_constant(0),
tlctx->get_constant(0)}));
}
// 4. store `offset` in `bit_ptr_struct` (if provided)
if (bit_offset) {
builder->CreateStore(
bit_offset,
builder->CreateGEP(bit_ptr_struct,
{tlctx->get_constant(0), tlctx->get_constant(1)}));
}
return bit_ptr_struct;
}

// Builds a new bit pointer struct that shares the byte pointer of
// `input_bit_ptr` and whose bit offset is the input's offset plus
// `bit_offset_delta`.
//
// Field 0 of the input struct is read as the byte pointer and field 1 as the
// bit offset; the result comes from create_bit_ptr_struct().
llvm::Value *CodeGenLLVM::offset_bit_ptr(llvm::Value *input_bit_ptr,
                                         int bit_offset_delta) {
  auto *zero = tlctx->get_constant(0);
  auto *one = tlctx->get_constant(1);

  // Field 0: byte pointer.
  auto *byte_ptr_field = builder->CreateGEP(input_bit_ptr, {zero, zero});
  auto *byte_ptr_base = builder->CreateLoad(byte_ptr_field);

  // Field 1: bit offset.
  auto *offset_field = builder->CreateGEP(input_bit_ptr, {zero, one});
  auto *input_offset = builder->CreateLoad(offset_field);

  auto *shifted_offset =
      builder->CreateAdd(input_offset, tlctx->get_constant(bit_offset_delta));
  return create_bit_ptr_struct(byte_ptr_base, shifted_offset);
}

void CodeGenLLVM::visit(SNodeLookupStmt *stmt) {
llvm::Value *parent = nullptr;
parent = llvm_val[stmt->input_snode];
Expand Down
16 changes: 13 additions & 3 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,16 +205,24 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(GlobalPtrStmt *stmt) override;

void store_custom_int(llvm::Value *bit_ptr,
CustomIntType *cit,
llvm::Value *value);

void visit(GlobalStoreStmt *stmt) override;

llvm::Value *load_as_custom_int(Stmt *ptr, Type *load_type);
llvm::Value *load_as_custom_int(llvm::Value *ptr, Type *load_type);

llvm::Value *extract_custom_int(llvm::Value *physical_value,
llvm::Value *bit_offset,
Type *load_type);

llvm::Value *reconstruct_custom_float(llvm::Value *digits, Type *load_type);

llvm::Value *load_custom_float_with_exponent(llvm::Value *digits_bit_ptr,
llvm::Value *exponent_bit_ptr,
CustomFloatType *cft);

void visit(GlobalLoadStmt *stmt) override;

void visit(ElementShuffleStmt *stmt) override;
Expand All @@ -227,8 +235,10 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(IntegerOffsetStmt *stmt) override;

llvm::Value *create_bit_ptr_struct(llvm::Value *byte_ptr_base,
llvm::Value *bit_offset);
llvm::Value *create_bit_ptr_struct(llvm::Value *byte_ptr_base = nullptr,
llvm::Value *bit_offset = nullptr);

llvm::Value *offset_bit_ptr(llvm::Value *input_bit_ptr, int bit_offset_delta);

void visit(SNodeLookupStmt *stmt) override;

Expand Down
Loading

0 comments on commit f1768ca

Please sign in to comment.