Commit 8a5035d
[refactor] Use PrimitiveType::type instead of DataType::type (#1926)
taichi-gardener authored Oct 6, 2020
1 parent 094baad commit 8a5035d
Showing 34 changed files with 361 additions and 344 deletions.
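At every call site the mechanical change is the spelling of the type constant. A minimal sketch of the before/after pattern, using a simplified stand-in type (the real definitions in taichi/ir are richer):

    // Simplified stand-in; illustrative only, not the actual taichi definition.
    enum class PrimitiveType { i32, i64, f32, f64, u32 };

    // Before this commit, call sites compared against DataType constants:
    //   if (stmt->val[0].dt == DataType::i32) { ... }
    // After, the same check reads:
    bool is_i32(PrimitiveType dt) {
      return dt == PrimitiveType::i32;
    }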
2 changes: 1 addition & 1 deletion docs/gui.rst
@@ -68,7 +68,7 @@ Taichi's GUI supports painting simple geometric objects, such as lines, triangle

.. note::

- The position parameter ``pos`` expects an input of a 2-element tuple, whose values are the relative position of the object.
+ The position parameter ``pos`` expects an input of a 2-element tuple, whose values are the relative position of the object.
``(0.0, 0.0)`` stands for the lower left corner of the window, and ``(1.0, 1.0)`` stands for the upper right corner.


4 changes: 2 additions & 2 deletions taichi/analysis/value_diff.cpp
@@ -56,7 +56,7 @@ class ValueDiffLoopIndex : public IRVisitor {
}

void visit(ConstStmt *stmt) override {
- if (stmt->val[lane].dt == DataType::i32) {
+ if (stmt->val[lane].dt == PrimitiveType::i32) {
results[stmt->instance_id] = DiffRange(true, 0, stmt->val[lane].val_i32);
} else {
results[stmt->instance_id] = DiffRange();
@@ -112,7 +112,7 @@ class FindDirectValueBaseAndOffset : public IRVisitor {

void visit(ConstStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
- if (stmt->val[0].dt == DataType::i32) {
+ if (stmt->val[0].dt == PrimitiveType::i32) {
result = std::make_tuple(true, nullptr, stmt->val[0].val_i32);
}
}
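For context, the visitor above answers "how does this value change across loop iterations?". A hedged sketch, assuming DiffRange(related, coeff, low) means value == coeff * loop_index + c for a constant c in [low, low + 1): an i32 constant k is then described exactly by coeff = 0 and the singleton range {k}, while anything else falls back to the default, unrelated DiffRange:

    // Assumed semantics, reconstructed from the calls above; not verbatim.
    struct DiffRange {
      bool related{false};  // true if the value is affine in the loop index
      int coeff{0}, low{0}, high{0};
      DiffRange() = default;  // "nothing known"
      DiffRange(bool related, int coeff, int low)
          : related(related), coeff(coeff), low(low), high(low + 1) {}
    };

    // An integer constant k does not vary with the loop index at all.
    DiffRange diff_of_const_i32(int k) {
      return DiffRange(true, /*coeff=*/0, /*low=*/k);
    }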
8 changes: 4 additions & 4 deletions taichi/backends/cc/codegen_cc.cpp
@@ -257,13 +257,13 @@ class CCTransformer : public IRVisitor {
}

static std::string _get_libc_function_name(std::string name, DataType dt) {
- if (dt == DataType::i32)
+ if (dt == PrimitiveType::i32)
return name;
- else if (dt == DataType::i64)
+ else if (dt == PrimitiveType::i64)
return "ll" + name;
- else if (dt == DataType::f32)
+ else if (dt == PrimitiveType::f32)
return name + "f";
- else if (dt == DataType::f64)
+ else if (dt == PrimitiveType::f64)
return name;
else
TI_ERROR("Unsupported function \"{}\" for DataType={} on C backend", name,
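The helper encodes libc's naming scheme: 32-bit integer functions use the plain name (abs), 64-bit integer variants get an ll prefix (llabs), single-precision math functions get an f suffix (sqrtf), and double precision uses the plain name (sqrt). A standalone restatement for illustration (the PrimitiveType stub here is a stand-in for the real one):

    #include <cassert>
    #include <string>

    enum class PrimitiveType { i32, i64, f32, f64 };  // illustrative stub

    std::string libc_function_name(const std::string &name, PrimitiveType dt) {
      if (dt == PrimitiveType::i32)
        return name;         // e.g. abs
      else if (dt == PrimitiveType::i64)
        return "ll" + name;  // e.g. llabs
      else if (dt == PrimitiveType::f32)
        return name + "f";   // e.g. sqrtf
      else
        return name;         // f64, e.g. sqrt
    }

    int main() {
      assert(libc_function_name("abs", PrimitiveType::i64) == "llabs");
      assert(libc_function_name("sqrt", PrimitiveType::f32) == "sqrtf");
    }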
2 changes: 1 addition & 1 deletion taichi/backends/cpu/codegen_cpu.cpp
@@ -37,7 +37,7 @@ class CodeGenLLVMCPU : public CodeGenLLVM {
llvm::Type::getInt8PtrTy(*llvm_context),
tlctx->get_data_type<int>()});

- auto loop_var = create_entry_block_alloca(DataType::i32);
+ auto loop_var = create_entry_block_alloca(PrimitiveType::i32);
loop_vars_llvm[stmt].push_back(loop_var);
builder->CreateStore(get_arg(2), loop_var);
stmt->body->accept(this);
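A note on the pattern above: LLVM's mem2reg/SROA passes only promote allocas that sit in the function's entry block, which is why codegen conventionally materializes loop variables there and then stores the incoming loop index (argument 2) into the slot. A minimal sketch of such a helper, assuming only the LLVM C++ API:

    #include "llvm/IR/Function.h"
    #include "llvm/IR/IRBuilder.h"

    // Sketch of create_entry_block_alloca for an i32 loop variable.
    llvm::AllocaInst *create_entry_block_alloca_i32(llvm::Function *func,
                                                    const char *name) {
      // Insert at the very top of the entry block so mem2reg can promote it.
      llvm::IRBuilder<> entry(&func->getEntryBlock(),
                              func->getEntryBlock().begin());
      return entry.CreateAlloca(llvm::Type::getInt32Ty(func->getContext()),
                                /*ArraySize=*/nullptr, name);
    }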
36 changes: 18 additions & 18 deletions taichi/backends/cuda/codegen_cuda.cpp
@@ -118,8 +118,8 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {

auto value_type = tlctx->get_data_type(arg_stmt->ret_type.data_type);
auto value = llvm_val[arg_stmt];
- if (arg_stmt->ret_type.data_type == DataType::f32) {
- value_type = tlctx->get_data_type(DataType::f64);
+ if (arg_stmt->ret_type.data_type == PrimitiveType::f32) {
+ value_type = tlctx->get_data_type(PrimitiveType::f64);
value = builder->CreateFPExt(value, value_type);
}
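The FPExt above exists because CUDA's vprintf takes a C-style variadic argument buffer, and C's default argument promotions require float to be widened to double. A host-side analogue of what the generated code does:

    #include <cstdio>

    void print_f32(float x) {
      // Varargs promote float to double; codegen makes the same widening
      // explicit with an FPExt before packing the printf argument buffer.
      std::printf("%f\n", static_cast<double>(x));
    }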

@@ -162,43 +162,43 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {

#define UNARY_STD(x) \
else if (op == UnaryOpType::x) { \
- if (input_taichi_type == DataType::f32) { \
+ if (input_taichi_type == PrimitiveType::f32) { \
llvm_val[stmt] = \
builder->CreateCall(get_runtime_function("__nv_" #x "f"), input); \
- } else if (input_taichi_type == DataType::f64) { \
+ } else if (input_taichi_type == PrimitiveType::f64) { \
llvm_val[stmt] = \
builder->CreateCall(get_runtime_function("__nv_" #x), input); \
- } else if (input_taichi_type == DataType::i32) { \
+ } else if (input_taichi_type == PrimitiveType::i32) { \
llvm_val[stmt] = builder->CreateCall(get_runtime_function(#x), input); \
} else { \
TI_NOT_IMPLEMENTED \
} \
}
if (op == UnaryOpType::abs) {
- if (input_taichi_type == DataType::f32) {
+ if (input_taichi_type == PrimitiveType::f32) {
llvm_val[stmt] =
builder->CreateCall(get_runtime_function("__nv_fabsf"), input);
- } else if (input_taichi_type == DataType::f64) {
+ } else if (input_taichi_type == PrimitiveType::f64) {
llvm_val[stmt] =
builder->CreateCall(get_runtime_function("__nv_fabs"), input);
- } else if (input_taichi_type == DataType::i32) {
+ } else if (input_taichi_type == PrimitiveType::i32) {
llvm_val[stmt] =
builder->CreateCall(get_runtime_function("__nv_abs"), input);
} else {
TI_NOT_IMPLEMENTED
}
} else if (op == UnaryOpType::sqrt) {
- if (input_taichi_type == DataType::f32) {
+ if (input_taichi_type == PrimitiveType::f32) {
llvm_val[stmt] =
builder->CreateCall(get_runtime_function("__nv_sqrtf"), input);
- } else if (input_taichi_type == DataType::f64) {
+ } else if (input_taichi_type == PrimitiveType::f64) {
llvm_val[stmt] =
builder->CreateCall(get_runtime_function("__nv_sqrt"), input);
} else {
TI_NOT_IMPLEMENTED
}
} else if (op == UnaryOpType::logic_not) {
- if (input_taichi_type == DataType::i32) {
+ if (input_taichi_type == PrimitiveType::i32) {
llvm_val[stmt] =
builder->CreateCall(get_runtime_function("logic_not_i32"), input);
} else {
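The UNARY_STD macro above encodes the libdevice naming convention: f32 ops call __nv_<name>f, f64 ops call __nv_<name>, and i32 falls back to a runtime helper named <name>. Restated as a plain function for illustration (stub enum, hypothetical function name):

    #include <string>

    enum class PrimitiveType { i32, f32, f64 };  // illustrative stub

    std::string unary_runtime_name(const std::string &op, PrimitiveType dt) {
      if (dt == PrimitiveType::f32)
        return "__nv_" + op + "f";  // e.g. __nv_sinf
      if (dt == PrimitiveType::f64)
        return "__nv_" + op;        // e.g. __nv_sin
      return op;                    // i32 runtime helper
    }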
@@ -236,11 +236,11 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
llvm::AtomicRMWInst::BinOp::Add, llvm_val[stmt->dest],
llvm_val[stmt->val],
llvm::AtomicOrdering::SequentiallyConsistent);
- } else if (stmt->val->ret_type.data_type == DataType::f32) {
+ } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::FAdd, llvm_val[stmt->dest],
llvm_val[stmt->val], AtomicOrdering::SequentiallyConsistent);
- } else if (stmt->val->ret_type.data_type == DataType::f64) {
+ } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::FAdd, llvm_val[stmt->dest],
llvm_val[stmt->val], AtomicOrdering::SequentiallyConsistent);
@@ -253,11 +253,11 @@
llvm::AtomicRMWInst::BinOp::Min, llvm_val[stmt->dest],
llvm_val[stmt->val],
llvm::AtomicOrdering::SequentiallyConsistent);
- } else if (stmt->val->ret_type.data_type == DataType::f32) {
+ } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) {
old_value =
builder->CreateCall(get_runtime_function("atomic_min_f32"),
{llvm_val[stmt->dest], llvm_val[stmt->val]});
- } else if (stmt->val->ret_type.data_type == DataType::f64) {
+ } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) {
old_value =
builder->CreateCall(get_runtime_function("atomic_min_f64"),
{llvm_val[stmt->dest], llvm_val[stmt->val]});
@@ -270,11 +270,11 @@
llvm::AtomicRMWInst::BinOp::Max, llvm_val[stmt->dest],
llvm_val[stmt->val],
llvm::AtomicOrdering::SequentiallyConsistent);
- } else if (stmt->val->ret_type.data_type == DataType::f32) {
+ } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) {
old_value =
builder->CreateCall(get_runtime_function("atomic_max_f32"),
{llvm_val[stmt->dest], llvm_val[stmt->val]});
- } else if (stmt->val->ret_type.data_type == DataType::f64) {
+ } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) {
old_value =
builder->CreateCall(get_runtime_function("atomic_max_f64"),
{llvm_val[stmt->dest], llvm_val[stmt->val]});
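Integer min/max and floating-point add map directly onto LLVM atomicrmw instructions, but LLVM at the time of this commit has no floating-point min/max RMW, so those cases call runtime helpers instead. A plausible sketch of such a helper as a compare-and-swap loop (the real atomic_min_f32 may differ):

    #include <atomic>

    float atomic_min_f32(std::atomic<float> *dest, float val) {
      float old = dest->load();
      // Retry until we either see a value already <= val or win the CAS;
      // compare_exchange_weak reloads `old` on failure.
      while (old > val && !dest->compare_exchange_weak(old, val)) {
      }
      return old;  // atomics conventionally return the previous value
    }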
@@ -334,7 +334,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
{llvm::PointerType::get(get_runtime_type("Context"), 0),
get_tls_buffer_type(), tlctx->get_data_type<int>()});

- auto loop_var = create_entry_block_alloca(DataType::i32);
+ auto loop_var = create_entry_block_alloca(PrimitiveType::i32);
loop_vars_llvm[stmt].push_back(loop_var);
builder->CreateStore(get_arg(2), loop_var);
stmt->body->accept(this);
13 changes: 7 additions & 6 deletions taichi/backends/metal/codegen_metal.cpp
@@ -275,7 +275,7 @@ class KernelCodegen : public IRVisitor {
}
} else if (opty == SNodeOpType::append) {
TI_ASSERT(is_dynamic);
- TI_ASSERT(stmt->ret_type.data_type == DataType::i32);
+ TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32);
emit("{} = {}.append({});", result_var, parent, stmt->val->raw_name());
} else if (opty == SNodeOpType::length) {
TI_ASSERT(is_dynamic);
@@ -485,19 +485,19 @@ class KernelCodegen : public IRVisitor {
current_appender().push_indent();
}

- if (dt == DataType::i32) {
+ if (dt == PrimitiveType::i32) {
emit(
"const auto {} = atomic_fetch_{}_explicit((device atomic_int*){}, "
"{}, "
"metal::memory_order_relaxed);",
stmt->raw_name(), op_name, stmt->dest->raw_name(), val_var);
- } else if (dt == DataType::u32) {
+ } else if (dt == PrimitiveType::u32) {
emit(
"const auto {} = atomic_fetch_{}_explicit((device atomic_uint*){}, "
"{}, "
"metal::memory_order_relaxed);",
stmt->raw_name(), op_name, stmt->dest->raw_name(), val_var);
- } else if (dt == DataType::f32) {
+ } else if (dt == PrimitiveType::f32) {
if (handle_float) {
emit("const float {} = fatomic_fetch_{}({}, {});", stmt->raw_name(),
op_name, stmt->dest->raw_name(), val_var);
@@ -624,7 +624,7 @@ class KernelCodegen : public IRVisitor {
if (std::holds_alternative<Stmt *>(entry)) {
auto *arg_stmt = std::get<Stmt *>(entry);
const auto dt = arg_stmt->element_type();
- TI_ASSERT_INFO(dt == DataType::i32 || dt == DataType::f32,
+ TI_ASSERT_INFO(dt == PrimitiveType::i32 || dt == PrimitiveType::f32,
"print() only supports i32 or f32 scalars for now.");
emit("{}.pm_set_{}({}, {});", msg_var_name, data_type_short_name(dt),
i, arg_stmt->raw_name());
@@ -1037,7 +1037,8 @@ class KernelCodegen : public IRVisitor {
used_features()->sparse = true;
}

- std::string inject_load_global_tmp(int offset, DataType dt = DataType::i32) {
+ std::string inject_load_global_tmp(int offset,
+                                    DataType dt = PrimitiveType::i32) {
const auto vt = VectorType(/*width=*/1, dt);
auto gtmp = Stmt::make<GlobalTemporaryStmt>(offset, vt);
gtmp->accept(this);
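The reworked default argument above (DataType dt = PrimitiveType::i32) suggests why signatures that keep DataType still accept the new constants: presumably DataType is now a handle implicitly constructible from a PrimitiveType tag. A hypothetical, simplified sketch of that relationship (not the actual taichi definitions):

    enum class PrimitiveTypeID { i32, f32 };

    struct PrimitiveType {  // namespace-like holder for the tag constants
      static constexpr PrimitiveTypeID i32 = PrimitiveTypeID::i32;
      static constexpr PrimitiveTypeID f32 = PrimitiveTypeID::f32;
    };

    struct DataType {
      PrimitiveTypeID id;
      DataType(PrimitiveTypeID id) : id(id) {}  // implicit by design
      bool operator==(PrimitiveTypeID other) const { return id == other; }
    };

    // Compiles because PrimitiveType::i32 converts to DataType.
    int load_global_tmp(int offset, DataType dt = PrimitiveType::i32);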
2 changes: 1 addition & 1 deletion taichi/backends/metal/data_types.cpp
@@ -4,7 +4,7 @@ TLANG_NAMESPACE_BEGIN
namespace metal {

MetalDataType to_metal_type(DataType dt) {
- #define METAL_CASE(x) else if (dt == DataType::x) return MetalDataType::x
+ #define METAL_CASE(x) else if (dt == PrimitiveType::x) return MetalDataType::x
if (false) {
}
METAL_CASE(f32);
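A side note on the macro idiom above: the empty if (false) {} head lets every METAL_CASE expansion begin with else if, so no entry in the list needs special casing. The same trick, with illustrative types:

    #include <stdexcept>

    enum class Host { f32, i32 };
    enum class Device { f32, i32 };

    Device to_device_type(Host dt) {
    #define DEVICE_CASE(x) else if (dt == Host::x) return Device::x
      if (false) {
      }  // seed: every macro expansion below can start with `else if`
      DEVICE_CASE(f32);
      DEVICE_CASE(i32);
    #undef DEVICE_CASE
      throw std::invalid_argument("unsupported type");
    }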
36 changes: 18 additions & 18 deletions taichi/backends/opengl/codegen_opengl.cpp
@@ -104,22 +104,22 @@ class KernelGen : public IRVisitor {
// Note that the following two functions not only returns the corresponding
// data type, but also **records** the usage of `i64` and `f64`.
std::string opengl_data_type_short_name(DataType dt) {
- if (dt == DataType::i64) {
+ if (dt == PrimitiveType::i64) {
if (!TI_OPENGL_REQUIRE(used, GL_ARB_gpu_shader_int64)) {
TI_ERROR(
"Extension GL_ARB_gpu_shader_int64 not supported on your OpenGL");
}
used.int64 = true;
}
- if (dt == DataType::f64)
+ if (dt == PrimitiveType::f64)
used.float64 = true;
return data_type_short_name(dt);
}

std::string opengl_data_type_name(DataType dt) {
- if (dt == DataType::i64)
+ if (dt == PrimitiveType::i64)
used.int64 = true;
- if (dt == DataType::f64)
+ if (dt == PrimitiveType::f64)
used.float64 = true;
return opengl::opengl_data_type_name(dt);
}
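As the comment in this hunk notes, these lookups double as bookkeeping: asking for a 64-bit type records that the generated shader must enable the matching GLSL extension. The shape of that record-as-you-translate pattern, sketched with a stub feature record:

    #include <string>

    struct UsedFeatures {  // illustrative stub of the backend's `used` record
      bool int64{false};
      bool float64{false};
    };

    enum class PrimitiveType { i32, i64, f32, f64 };

    std::string glsl_type_name(PrimitiveType dt, UsedFeatures &used) {
      switch (dt) {
        case PrimitiveType::i64:
          used.int64 = true;  // shader must declare GL_ARB_gpu_shader_int64
          return "int64_t";
        case PrimitiveType::f64:
          used.float64 = true;
          return "double";
        case PrimitiveType::i32:
          return "int";
        default:
          return "float";
      }
    }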
@@ -360,7 +360,7 @@ class KernelGen : public IRVisitor {
}

} else if (stmt->op_type == SNodeOpType::is_active) {
- TI_ASSERT(stmt->ret_type.data_type == DataType::i32);
+ TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32);
if (stmt->snode->type == SNodeType::dense ||
stmt->snode->type == SNodeType::root) {
emit("int {} = 1;", stmt->short_name());
@@ -373,7 +373,7 @@

} else if (stmt->op_type == SNodeOpType::append) {
TI_ASSERT(stmt->snode->type == SNodeType::dynamic);
- TI_ASSERT(stmt->ret_type.data_type == DataType::i32);
+ TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32);
emit("int {} = atomicAdd(_data_i32_[{} >> 2], 1);", stmt->short_name(),
get_snode_meta_address(stmt->snode));
auto dt = stmt->val->element_type();
@@ -387,7 +387,7 @@

} else if (stmt->op_type == SNodeOpType::length) {
TI_ASSERT(stmt->snode->type == SNodeType::dynamic);
- TI_ASSERT(stmt->ret_type.data_type == DataType::i32);
+ TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32);
emit("int {} = _data_i32_[{} >> 2];", stmt->short_name(),
get_snode_meta_address(stmt->snode));

@@ -479,12 +479,12 @@ class KernelGen : public IRVisitor {
emit("{} {} = {}({});", dt_name, stmt->short_name(),
opengl_data_type_name(stmt->cast_type), stmt->operand->short_name());
} else if (stmt->op_type == UnaryOpType::cast_bits) {
- if (stmt->cast_type == DataType::f32 &&
- stmt->operand->element_type() == DataType::i32) {
+ if (stmt->cast_type == PrimitiveType::f32 &&
+ stmt->operand->element_type() == PrimitiveType::i32) {
emit("{} {} = intBitsToFloat({});", dt_name, stmt->short_name(),
stmt->operand->short_name());
- } else if (stmt->cast_type == DataType::i32 &&
- stmt->operand->element_type() == DataType::f32) {
+ } else if (stmt->cast_type == PrimitiveType::i32 &&
+ stmt->operand->element_type() == PrimitiveType::f32) {
emit("{} {} = floatBitsToInt({});", dt_name, stmt->short_name(),
stmt->operand->short_name());
} else {
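The cast_bits branch maps to GLSL's intBitsToFloat/floatBitsToInt, which reinterpret the 32-bit pattern rather than converting the value (the cast_value branch earlier in the hunk does the converting kind). A C++ analogue of the bit-preserving direction:

    #include <cstdint>
    #include <cstring>

    float int_bits_to_float(std::int32_t i) {
      float f;
      std::memcpy(&f, &i, sizeof(f));  // reinterpret bits, no numeric conversion
      return f;  // e.g. 0x3f800000 comes back as 1.0f
    }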
@@ -527,7 +527,7 @@ class KernelGen : public IRVisitor {
return;
} else if (bin->op_type == BinaryOpType::atan2) {
if (bin->element_type() ==
- DataType::f64) { // don't know why no atan(double, double)
+ PrimitiveType::f64) { // don't know why no atan(double, double)
emit("{} {} = {}(atan(float({}), float({})));", dt_name, bin_name,
dt_name, lhs_name, rhs_name);
} else {
@@ -573,25 +573,25 @@ class KernelGen : public IRVisitor {
void visit(AtomicOpStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
auto dt = stmt->dest->element_type();
- if (dt == DataType::i32 ||
+ if (dt == PrimitiveType::i32 ||
(TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_int64) &&
- dt == DataType::i64) ||
+ dt == PrimitiveType::i64) ||
((stmt->op_type == AtomicOpType::add ||
stmt->op_type == AtomicOpType::sub) &&
((TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_float) &&
- dt == DataType::f32) ||
+ dt == PrimitiveType::f32) ||
(TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_float64) &&
- dt == DataType::f64)))) {
+ dt == PrimitiveType::f64)))) {
emit("{} {} = {}(_{}_{}_[{} >> {}], {});",
opengl_data_type_name(stmt->val->element_type()), stmt->short_name(),
opengl_atomic_op_type_cap_name(stmt->op_type),
ptr_signats.at(stmt->dest->id), opengl_data_type_short_name(dt),
stmt->dest->short_name(), opengl_data_address_shifter(dt),
stmt->val->short_name());
} else {
- if (dt != DataType::f32) {
+ if (dt != PrimitiveType::f32) {
TI_ERROR(
"unsupported atomic operation for DataType::{}, "
"unsupported atomic operation for PrimitiveType::{}, "
"this may because your OpenGL is missing that extension, "
"see `glewinfo` for more details",
opengl_data_type_short_name(dt));
12 changes: 6 additions & 6 deletions taichi/backends/opengl/opengl_data_types.h
@@ -8,13 +8,13 @@ namespace opengl {

inline std::string opengl_data_type_name(DataType dt) {
// https://www.khronos.org/opengl/wiki/Data_Type_(GLSL)
- if (dt == DataType::f32)
+ if (dt == PrimitiveType::f32)
return "float";
- else if (dt == DataType::f64)
+ else if (dt == PrimitiveType::f64)
return "double";
- else if (dt == DataType::i32)
+ else if (dt == PrimitiveType::i32)
return "int";
- else if (dt == DataType::i64)
+ else if (dt == PrimitiveType::i64)
return "int64_t";
else
TI_NOT_IMPLEMENTED;
@@ -32,9 +32,9 @@ inline bool is_opengl_binary_op_different_return_type(BinaryOpType type) {
}

inline int opengl_data_address_shifter(DataType type) {
- if (type == DataType::f32 || type == DataType::i32)
+ if (type == PrimitiveType::f32 || type == PrimitiveType::i32)
return 2;
- else if (type == DataType::f64 || type == DataType::i64) {
+ else if (type == PrimitiveType::f64 || type == PrimitiveType::i64) {
return 3;
} else {
TI_NOT_IMPLEMENTED
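opengl_data_address_shifter completes the picture from the SNode hunk above: the returned shift is log2(sizeof(type)), 2 for 4-byte f32/i32 and 3 for 8-byte f64/i64, so emitted indexing of the form buf[addr >> shift] stays type-correct. A quick restatement:

    enum class PrimitiveType { i32, i64, f32, f64 };  // illustrative stub

    int data_address_shifter(PrimitiveType t) {
      if (t == PrimitiveType::f32 || t == PrimitiveType::i32)
        return 2;  // 4-byte elements: byte_offset >> 2
      return 3;    // 8-byte elements: byte_offset >> 3
    }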
[Diff truncated: 9 of the 34 changed files are shown above.]
