Skip to content

Commit

Permalink
[type] [refactor] Promote DataType to a class (#1906)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanming-hu authored Oct 4, 2020
1 parent 16e6bc3 commit 7db99e0
Show file tree
Hide file tree
Showing 13 changed files with 255 additions and 172 deletions.
1 change: 1 addition & 0 deletions misc/prtags.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,6 @@
"error" : "Error messages",
"blender" : "Blender intergration",
"export" : "Exporting kernels",
"type" : "Type system",
"release" : "Release"
}
20 changes: 10 additions & 10 deletions python/taichi/lang/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,32 @@ def is_taichi_class(rhs):

# Real types

float32 = taichi_lang_core.DataType.float32
float32 = taichi_lang_core.DataType_float32
f32 = float32
float64 = taichi_lang_core.DataType.float64
float64 = taichi_lang_core.DataType_float64
f64 = float64

real_types = [f32, f64, float]
real_type_ids = [id(t) for t in real_types]

# Integer types

int8 = taichi_lang_core.DataType.int8
int8 = taichi_lang_core.DataType_int8
i8 = int8
int16 = taichi_lang_core.DataType.int16
int16 = taichi_lang_core.DataType_int16
i16 = int16
int32 = taichi_lang_core.DataType.int32
int32 = taichi_lang_core.DataType_int32
i32 = int32
int64 = taichi_lang_core.DataType.int64
int64 = taichi_lang_core.DataType_int64
i64 = int64

uint8 = taichi_lang_core.DataType.uint8
uint8 = taichi_lang_core.DataType_uint8
u8 = uint8
uint16 = taichi_lang_core.DataType.uint16
uint16 = taichi_lang_core.DataType_uint16
u16 = uint16
uint32 = taichi_lang_core.DataType.uint32
uint32 = taichi_lang_core.DataType_uint32
u32 = uint32
uint64 = taichi_lang_core.DataType.uint64
uint64 = taichi_lang_core.DataType_uint64
u64 = uint64

integer_types = [i8, i16, i32, i64, u8, u16, u32, u64, int]
Expand Down
26 changes: 12 additions & 14 deletions taichi/backends/cc/codegen_cc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,19 +257,17 @@ class CCTransformer : public IRVisitor {
}

static std::string _get_libc_function_name(std::string name, DataType dt) {
switch (dt) {
case DataType::i32:
return name;
case DataType::i64:
return "ll" + name;
case DataType::f32:
return name + "f";
case DataType::f64:
return name;
default:
TI_ERROR("Unsupported function \"{}\" for DataType={} on C backend",
name, data_type_name(dt));
}
if (dt == DataType::i32)
return name;
else if (dt == DataType::i64)
return "ll" + name;
else if (dt == DataType::f32)
return name + "f";
else if (dt == DataType::f64)
return name;
else
TI_ERROR("Unsupported function \"{}\" for DataType={} on C backend", name,
data_type_name(dt));
}

static std::string get_libc_function_name(std::string name, DataType dt) {
Expand Down Expand Up @@ -598,7 +596,7 @@ class CCTransformer : public IRVisitor {
void emit_header(std::string f, Args &&... args) {
line_appender_header.append(std::move(f), std::move(args)...);
}
};
}; // namespace cccp

std::unique_ptr<CCKernel> CCKernelGen::compile() {
auto program = kernel->program.cc_program.get();
Expand Down
38 changes: 17 additions & 21 deletions taichi/backends/metal/data_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,24 @@ TLANG_NAMESPACE_BEGIN
namespace metal {

MetalDataType to_metal_type(DataType dt) {
switch (dt) {
#define METAL_CASE(x) \
case DataType::x: \
return MetalDataType::x

METAL_CASE(f32);
METAL_CASE(f64);
METAL_CASE(i8);
METAL_CASE(i16);
METAL_CASE(i32);
METAL_CASE(i64);
METAL_CASE(u8);
METAL_CASE(u16);
METAL_CASE(u32);
METAL_CASE(u64);
METAL_CASE(unknown);
#undef METAL_CASE

default:
TI_NOT_IMPLEMENTED;
break;
#define METAL_CASE(x) else if (dt == DataType::x) return MetalDataType::x
if (false) {
}
METAL_CASE(f32);
METAL_CASE(f64);
METAL_CASE(i8);
METAL_CASE(i16);
METAL_CASE(i32);
METAL_CASE(i64);
METAL_CASE(u8);
METAL_CASE(u16);
METAL_CASE(u32);
METAL_CASE(u64);
METAL_CASE(unknown);
else {
TI_NOT_IMPLEMENTED;
}
#undef METAL_CASE
return MetalDataType::unknown;
}

Expand Down
39 changes: 16 additions & 23 deletions taichi/backends/opengl/opengl_data_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,16 @@ namespace opengl {

inline std::string opengl_data_type_name(DataType dt) {
// https://www.khronos.org/opengl/wiki/Data_Type_(GLSL)
switch (dt) {
case DataType::f32:
return "float";
case DataType::f64:
return "double";
case DataType::i32:
return "int";
case DataType::i64:
return "int64_t";
default:
TI_NOT_IMPLEMENTED;
break;
}
return "";
if (dt == DataType::f32)
return "float";
else if (dt == DataType::f64)
return "double";
else if (dt == DataType::i32)
return "int";
else if (dt == DataType::i64)
return "int64_t";
else
TI_NOT_IMPLEMENTED;
}

inline bool is_opengl_binary_op_infix(BinaryOpType type) {
Expand All @@ -36,15 +32,12 @@ inline bool is_opengl_binary_op_different_return_type(BinaryOpType type) {
}

inline int opengl_data_address_shifter(DataType type) {
switch (type) {
case DataType::f32:
case DataType::i32:
return 2;
case DataType::f64:
case DataType::i64:
return 3;
default:
TI_NOT_IMPLEMENTED
if (type == DataType::f32 || type == DataType::i32)
return 2;
else if (type == DataType::f64 || type == DataType::i64) {
return 3;
} else {
TI_NOT_IMPLEMENTED
}
}

Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,7 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) {
dest_ty = tlctx->get_data_type(stmt->ret_type.data_type);
auto dest_bits = dest_ty->getPrimitiveSizeInBits();
auto truncated = builder->CreateTrunc(
raw_arg, Type::getIntNTy(*llvm_context, dest_bits));
raw_arg, llvm::Type::getIntNTy(*llvm_context, dest_bits));
llvm_val[stmt] = builder->CreateBitCast(truncated, dest_ty);
}
}
Expand Down Expand Up @@ -1327,7 +1327,7 @@ void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) {

// per-leaf-block for loop
auto loop_index =
create_entry_block_alloca(Type::getInt32Ty(*llvm_context));
create_entry_block_alloca(llvm::Type::getInt32Ty(*llvm_context));

llvm::Value *thread_idx = nullptr, *block_dim = nullptr;

Expand Down
Loading

0 comments on commit 7db99e0

Please sign in to comment.