From 8a5035de840f272dc115d7954e10b6c11b340938 Mon Sep 17 00:00:00 2001 From: Taichi Gardener <62079278+taichi-gardener@users.noreply.github.com> Date: Tue, 6 Oct 2020 01:00:18 -0400 Subject: [PATCH] [refactor] Use PrimitiveType::type instead of DataType::type (#1926) --- docs/gui.rst | 2 +- taichi/analysis/value_diff.cpp | 4 +- taichi/backends/cc/codegen_cc.cpp | 8 +- taichi/backends/cpu/codegen_cpu.cpp | 2 +- taichi/backends/cuda/codegen_cuda.cpp | 36 +++---- taichi/backends/metal/codegen_metal.cpp | 13 +-- taichi/backends/metal/data_types.cpp | 2 +- taichi/backends/opengl/codegen_opengl.cpp | 36 +++---- taichi/backends/opengl/opengl_data_types.h | 12 +-- taichi/codegen/codegen_llvm.cpp | 91 +++++++++--------- taichi/ir/expr.cpp | 3 +- taichi/ir/frontend.h | 13 +-- taichi/ir/frontend_ir.h | 2 +- taichi/ir/ir.cpp | 2 +- taichi/ir/ir.h | 3 +- taichi/ir/snode.cpp | 2 +- taichi/ir/statements.cpp | 8 +- taichi/ir/statements.h | 4 +- taichi/lang_util.cpp | 93 +++++++++--------- taichi/lang_util.h | 96 ++++++++++--------- taichi/llvm/llvm_context.cpp | 32 +++---- taichi/program/compile_config.cpp | 4 +- taichi/program/kernel.cpp | 80 ++++++++-------- taichi/program/kernel.h | 4 +- taichi/program/program.cpp | 8 +- taichi/python/export_lang.cpp | 7 +- taichi/transforms/alg_simp.cpp | 4 +- taichi/transforms/constant_fold.cpp | 4 +- .../transforms/demote_dense_struct_fors.cpp | 2 +- taichi/transforms/lower_ast.cpp | 17 ++-- taichi/transforms/simplify.cpp | 12 +-- taichi/transforms/type_check.cpp | 51 +++++----- tests/cpp/test_alg_simp.cpp | 40 ++++---- tests/cpp/test_same_statements.cpp | 8 +- 34 files changed, 361 insertions(+), 344 deletions(-) diff --git a/docs/gui.rst b/docs/gui.rst index 7fd2353b78e51..5bf5d1c33aafb 100644 --- a/docs/gui.rst +++ b/docs/gui.rst @@ -68,7 +68,7 @@ Taichi's GUI supports painting simple geometric objects, such as lines, triangle .. note:: - The position parameter ``pos`` expects an input of a 2-element tuple, whose values are the relative position of the object. + The position parameter ``pos`` expects an input of a 2-element tuple, whose values are the relative position of the object. ``(0.0, 0.0)`` stands for the lower left corner of the window, and ``(1.0, 1.0)`` stands for the upper right corner. diff --git a/taichi/analysis/value_diff.cpp b/taichi/analysis/value_diff.cpp index c1e3eb08fe197..65dcead02e694 100644 --- a/taichi/analysis/value_diff.cpp +++ b/taichi/analysis/value_diff.cpp @@ -56,7 +56,7 @@ class ValueDiffLoopIndex : public IRVisitor { } void visit(ConstStmt *stmt) override { - if (stmt->val[lane].dt == DataType::i32) { + if (stmt->val[lane].dt == PrimitiveType::i32) { results[stmt->instance_id] = DiffRange(true, 0, stmt->val[lane].val_i32); } else { results[stmt->instance_id] = DiffRange(); @@ -112,7 +112,7 @@ class FindDirectValueBaseAndOffset : public IRVisitor { void visit(ConstStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - if (stmt->val[0].dt == DataType::i32) { + if (stmt->val[0].dt == PrimitiveType::i32) { result = std::make_tuple(true, nullptr, stmt->val[0].val_i32); } } diff --git a/taichi/backends/cc/codegen_cc.cpp b/taichi/backends/cc/codegen_cc.cpp index 2170057ebda42..d7ca78881411d 100644 --- a/taichi/backends/cc/codegen_cc.cpp +++ b/taichi/backends/cc/codegen_cc.cpp @@ -257,13 +257,13 @@ class CCTransformer : public IRVisitor { } static std::string _get_libc_function_name(std::string name, DataType dt) { - if (dt == DataType::i32) + if (dt == PrimitiveType::i32) return name; - else if (dt == DataType::i64) + else if (dt == PrimitiveType::i64) return "ll" + name; - else if (dt == DataType::f32) + else if (dt == PrimitiveType::f32) return name + "f"; - else if (dt == DataType::f64) + else if (dt == PrimitiveType::f64) return name; else TI_ERROR("Unsupported function \"{}\" for DataType={} on C backend", name, diff --git a/taichi/backends/cpu/codegen_cpu.cpp b/taichi/backends/cpu/codegen_cpu.cpp index 71e91c365a7c2..0c28960be9d32 100644 --- a/taichi/backends/cpu/codegen_cpu.cpp +++ b/taichi/backends/cpu/codegen_cpu.cpp @@ -37,7 +37,7 @@ class CodeGenLLVMCPU : public CodeGenLLVM { llvm::Type::getInt8PtrTy(*llvm_context), tlctx->get_data_type()}); - auto loop_var = create_entry_block_alloca(DataType::i32); + auto loop_var = create_entry_block_alloca(PrimitiveType::i32); loop_vars_llvm[stmt].push_back(loop_var); builder->CreateStore(get_arg(2), loop_var); stmt->body->accept(this); diff --git a/taichi/backends/cuda/codegen_cuda.cpp b/taichi/backends/cuda/codegen_cuda.cpp index 85de6e7b63e6a..1c19ccb581348 100644 --- a/taichi/backends/cuda/codegen_cuda.cpp +++ b/taichi/backends/cuda/codegen_cuda.cpp @@ -118,8 +118,8 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { auto value_type = tlctx->get_data_type(arg_stmt->ret_type.data_type); auto value = llvm_val[arg_stmt]; - if (arg_stmt->ret_type.data_type == DataType::f32) { - value_type = tlctx->get_data_type(DataType::f64); + if (arg_stmt->ret_type.data_type == PrimitiveType::f32) { + value_type = tlctx->get_data_type(PrimitiveType::f64); value = builder->CreateFPExt(value, value_type); } @@ -162,43 +162,43 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { #define UNARY_STD(x) \ else if (op == UnaryOpType::x) { \ - if (input_taichi_type == DataType::f32) { \ + if (input_taichi_type == PrimitiveType::f32) { \ llvm_val[stmt] = \ builder->CreateCall(get_runtime_function("__nv_" #x "f"), input); \ - } else if (input_taichi_type == DataType::f64) { \ + } else if (input_taichi_type == PrimitiveType::f64) { \ llvm_val[stmt] = \ builder->CreateCall(get_runtime_function("__nv_" #x), input); \ - } else if (input_taichi_type == DataType::i32) { \ + } else if (input_taichi_type == PrimitiveType::i32) { \ llvm_val[stmt] = builder->CreateCall(get_runtime_function(#x), input); \ } else { \ TI_NOT_IMPLEMENTED \ } \ } if (op == UnaryOpType::abs) { - if (input_taichi_type == DataType::f32) { + if (input_taichi_type == PrimitiveType::f32) { llvm_val[stmt] = builder->CreateCall(get_runtime_function("__nv_fabsf"), input); - } else if (input_taichi_type == DataType::f64) { + } else if (input_taichi_type == PrimitiveType::f64) { llvm_val[stmt] = builder->CreateCall(get_runtime_function("__nv_fabs"), input); - } else if (input_taichi_type == DataType::i32) { + } else if (input_taichi_type == PrimitiveType::i32) { llvm_val[stmt] = builder->CreateCall(get_runtime_function("__nv_abs"), input); } else { TI_NOT_IMPLEMENTED } } else if (op == UnaryOpType::sqrt) { - if (input_taichi_type == DataType::f32) { + if (input_taichi_type == PrimitiveType::f32) { llvm_val[stmt] = builder->CreateCall(get_runtime_function("__nv_sqrtf"), input); - } else if (input_taichi_type == DataType::f64) { + } else if (input_taichi_type == PrimitiveType::f64) { llvm_val[stmt] = builder->CreateCall(get_runtime_function("__nv_sqrt"), input); } else { TI_NOT_IMPLEMENTED } } else if (op == UnaryOpType::logic_not) { - if (input_taichi_type == DataType::i32) { + if (input_taichi_type == PrimitiveType::i32) { llvm_val[stmt] = builder->CreateCall(get_runtime_function("logic_not_i32"), input); } else { @@ -236,11 +236,11 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { llvm::AtomicRMWInst::BinOp::Add, llvm_val[stmt->dest], llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent); - } else if (stmt->val->ret_type.data_type == DataType::f32) { + } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::FAdd, llvm_val[stmt->dest], llvm_val[stmt->val], AtomicOrdering::SequentiallyConsistent); - } else if (stmt->val->ret_type.data_type == DataType::f64) { + } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::FAdd, llvm_val[stmt->dest], llvm_val[stmt->val], AtomicOrdering::SequentiallyConsistent); @@ -253,11 +253,11 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { llvm::AtomicRMWInst::BinOp::Min, llvm_val[stmt->dest], llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent); - } else if (stmt->val->ret_type.data_type == DataType::f32) { + } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) { old_value = builder->CreateCall(get_runtime_function("atomic_min_f32"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); - } else if (stmt->val->ret_type.data_type == DataType::f64) { + } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) { old_value = builder->CreateCall(get_runtime_function("atomic_min_f64"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); @@ -270,11 +270,11 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { llvm::AtomicRMWInst::BinOp::Max, llvm_val[stmt->dest], llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent); - } else if (stmt->val->ret_type.data_type == DataType::f32) { + } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) { old_value = builder->CreateCall(get_runtime_function("atomic_max_f32"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); - } else if (stmt->val->ret_type.data_type == DataType::f64) { + } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) { old_value = builder->CreateCall(get_runtime_function("atomic_max_f64"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); @@ -334,7 +334,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { {llvm::PointerType::get(get_runtime_type("Context"), 0), get_tls_buffer_type(), tlctx->get_data_type()}); - auto loop_var = create_entry_block_alloca(DataType::i32); + auto loop_var = create_entry_block_alloca(PrimitiveType::i32); loop_vars_llvm[stmt].push_back(loop_var); builder->CreateStore(get_arg(2), loop_var); stmt->body->accept(this); diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index 0b748abfc221b..aa49f1a1e05e1 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -275,7 +275,7 @@ class KernelCodegen : public IRVisitor { } } else if (opty == SNodeOpType::append) { TI_ASSERT(is_dynamic); - TI_ASSERT(stmt->ret_type.data_type == DataType::i32); + TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32); emit("{} = {}.append({});", result_var, parent, stmt->val->raw_name()); } else if (opty == SNodeOpType::length) { TI_ASSERT(is_dynamic); @@ -485,19 +485,19 @@ class KernelCodegen : public IRVisitor { current_appender().push_indent(); } - if (dt == DataType::i32) { + if (dt == PrimitiveType::i32) { emit( "const auto {} = atomic_fetch_{}_explicit((device atomic_int*){}, " "{}, " "metal::memory_order_relaxed);", stmt->raw_name(), op_name, stmt->dest->raw_name(), val_var); - } else if (dt == DataType::u32) { + } else if (dt == PrimitiveType::u32) { emit( "const auto {} = atomic_fetch_{}_explicit((device atomic_uint*){}, " "{}, " "metal::memory_order_relaxed);", stmt->raw_name(), op_name, stmt->dest->raw_name(), val_var); - } else if (dt == DataType::f32) { + } else if (dt == PrimitiveType::f32) { if (handle_float) { emit("const float {} = fatomic_fetch_{}({}, {});", stmt->raw_name(), op_name, stmt->dest->raw_name(), val_var); @@ -624,7 +624,7 @@ class KernelCodegen : public IRVisitor { if (std::holds_alternative(entry)) { auto *arg_stmt = std::get(entry); const auto dt = arg_stmt->element_type(); - TI_ASSERT_INFO(dt == DataType::i32 || dt == DataType::f32, + TI_ASSERT_INFO(dt == PrimitiveType::i32 || dt == PrimitiveType::f32, "print() only supports i32 or f32 scalars for now."); emit("{}.pm_set_{}({}, {});", msg_var_name, data_type_short_name(dt), i, arg_stmt->raw_name()); @@ -1037,7 +1037,8 @@ class KernelCodegen : public IRVisitor { used_features()->sparse = true; } - std::string inject_load_global_tmp(int offset, DataType dt = DataType::i32) { + std::string inject_load_global_tmp(int offset, + DataType dt = PrimitiveType::i32) { const auto vt = VectorType(/*width=*/1, dt); auto gtmp = Stmt::make(offset, vt); gtmp->accept(this); diff --git a/taichi/backends/metal/data_types.cpp b/taichi/backends/metal/data_types.cpp index 11fd443ec666c..101251a59a236 100644 --- a/taichi/backends/metal/data_types.cpp +++ b/taichi/backends/metal/data_types.cpp @@ -4,7 +4,7 @@ TLANG_NAMESPACE_BEGIN namespace metal { MetalDataType to_metal_type(DataType dt) { -#define METAL_CASE(x) else if (dt == DataType::x) return MetalDataType::x +#define METAL_CASE(x) else if (dt == PrimitiveType::x) return MetalDataType::x if (false) { } METAL_CASE(f32); diff --git a/taichi/backends/opengl/codegen_opengl.cpp b/taichi/backends/opengl/codegen_opengl.cpp index 54e35beb01ae0..273bb881d416c 100644 --- a/taichi/backends/opengl/codegen_opengl.cpp +++ b/taichi/backends/opengl/codegen_opengl.cpp @@ -104,22 +104,22 @@ class KernelGen : public IRVisitor { // Note that the following two functions not only returns the corresponding // data type, but also **records** the usage of `i64` and `f64`. std::string opengl_data_type_short_name(DataType dt) { - if (dt == DataType::i64) { + if (dt == PrimitiveType::i64) { if (!TI_OPENGL_REQUIRE(used, GL_ARB_gpu_shader_int64)) { TI_ERROR( "Extension GL_ARB_gpu_shader_int64 not supported on your OpenGL"); } used.int64 = true; } - if (dt == DataType::f64) + if (dt == PrimitiveType::f64) used.float64 = true; return data_type_short_name(dt); } std::string opengl_data_type_name(DataType dt) { - if (dt == DataType::i64) + if (dt == PrimitiveType::i64) used.int64 = true; - if (dt == DataType::f64) + if (dt == PrimitiveType::f64) used.float64 = true; return opengl::opengl_data_type_name(dt); } @@ -360,7 +360,7 @@ class KernelGen : public IRVisitor { } } else if (stmt->op_type == SNodeOpType::is_active) { - TI_ASSERT(stmt->ret_type.data_type == DataType::i32); + TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32); if (stmt->snode->type == SNodeType::dense || stmt->snode->type == SNodeType::root) { emit("int {} = 1;", stmt->short_name()); @@ -373,7 +373,7 @@ class KernelGen : public IRVisitor { } else if (stmt->op_type == SNodeOpType::append) { TI_ASSERT(stmt->snode->type == SNodeType::dynamic); - TI_ASSERT(stmt->ret_type.data_type == DataType::i32); + TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32); emit("int {} = atomicAdd(_data_i32_[{} >> 2], 1);", stmt->short_name(), get_snode_meta_address(stmt->snode)); auto dt = stmt->val->element_type(); @@ -387,7 +387,7 @@ class KernelGen : public IRVisitor { } else if (stmt->op_type == SNodeOpType::length) { TI_ASSERT(stmt->snode->type == SNodeType::dynamic); - TI_ASSERT(stmt->ret_type.data_type == DataType::i32); + TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32); emit("int {} = _data_i32_[{} >> 2];", stmt->short_name(), get_snode_meta_address(stmt->snode)); @@ -479,12 +479,12 @@ class KernelGen : public IRVisitor { emit("{} {} = {}({});", dt_name, stmt->short_name(), opengl_data_type_name(stmt->cast_type), stmt->operand->short_name()); } else if (stmt->op_type == UnaryOpType::cast_bits) { - if (stmt->cast_type == DataType::f32 && - stmt->operand->element_type() == DataType::i32) { + if (stmt->cast_type == PrimitiveType::f32 && + stmt->operand->element_type() == PrimitiveType::i32) { emit("{} {} = intBitsToFloat({});", dt_name, stmt->short_name(), stmt->operand->short_name()); - } else if (stmt->cast_type == DataType::i32 && - stmt->operand->element_type() == DataType::f32) { + } else if (stmt->cast_type == PrimitiveType::i32 && + stmt->operand->element_type() == PrimitiveType::f32) { emit("{} {} = floatBitsToInt({});", dt_name, stmt->short_name(), stmt->operand->short_name()); } else { @@ -527,7 +527,7 @@ class KernelGen : public IRVisitor { return; } else if (bin->op_type == BinaryOpType::atan2) { if (bin->element_type() == - DataType::f64) { // don't know why no atan(double, double) + PrimitiveType::f64) { // don't know why no atan(double, double) emit("{} {} = {}(atan(float({}), float({})));", dt_name, bin_name, dt_name, lhs_name, rhs_name); } else { @@ -573,15 +573,15 @@ class KernelGen : public IRVisitor { void visit(AtomicOpStmt *stmt) override { TI_ASSERT(stmt->width() == 1); auto dt = stmt->dest->element_type(); - if (dt == DataType::i32 || + if (dt == PrimitiveType::i32 || (TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_int64) && - dt == DataType::i64) || + dt == PrimitiveType::i64) || ((stmt->op_type == AtomicOpType::add || stmt->op_type == AtomicOpType::sub) && ((TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_float) && - dt == DataType::f32) || + dt == PrimitiveType::f32) || (TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_float64) && - dt == DataType::f64)))) { + dt == PrimitiveType::f64)))) { emit("{} {} = {}(_{}_{}_[{} >> {}], {});", opengl_data_type_name(stmt->val->element_type()), stmt->short_name(), opengl_atomic_op_type_cap_name(stmt->op_type), @@ -589,9 +589,9 @@ class KernelGen : public IRVisitor { stmt->dest->short_name(), opengl_data_address_shifter(dt), stmt->val->short_name()); } else { - if (dt != DataType::f32) { + if (dt != PrimitiveType::f32) { TI_ERROR( - "unsupported atomic operation for DataType::{}, " + "unsupported atomic operation for PrimitiveType::{}, " "this may because your OpenGL is missing that extension, " "see `glewinfo` for more details", opengl_data_type_short_name(dt)); diff --git a/taichi/backends/opengl/opengl_data_types.h b/taichi/backends/opengl/opengl_data_types.h index 2834db90a992a..e80018d9fc36d 100644 --- a/taichi/backends/opengl/opengl_data_types.h +++ b/taichi/backends/opengl/opengl_data_types.h @@ -8,13 +8,13 @@ namespace opengl { inline std::string opengl_data_type_name(DataType dt) { // https://www.khronos.org/opengl/wiki/Data_Type_(GLSL) - if (dt == DataType::f32) + if (dt == PrimitiveType::f32) return "float"; - else if (dt == DataType::f64) + else if (dt == PrimitiveType::f64) return "double"; - else if (dt == DataType::i32) + else if (dt == PrimitiveType::i32) return "int"; - else if (dt == DataType::i64) + else if (dt == PrimitiveType::i64) return "int64_t"; else TI_NOT_IMPLEMENTED; @@ -32,9 +32,9 @@ inline bool is_opengl_binary_op_different_return_type(BinaryOpType type) { } inline int opengl_data_address_shifter(DataType type) { - if (type == DataType::f32 || type == DataType::i32) + if (type == PrimitiveType::f32 || type == PrimitiveType::i32) return 2; - else if (type == DataType::f64 || type == DataType::i64) { + else if (type == PrimitiveType::f64 || type == PrimitiveType::i64) { return 3; } else { TI_NOT_IMPLEMENTED diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 4bbc4efe46995..dc2337ecbd779 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -149,13 +149,13 @@ void CodeGenLLVM::emit_extra_unary(UnaryOpStmt *stmt) { #define UNARY_STD(x) \ else if (op == UnaryOpType::x) { \ - if (input_taichi_type == DataType::f32) { \ + if (input_taichi_type == PrimitiveType::f32) { \ llvm_val[stmt] = \ builder->CreateCall(get_runtime_function(#x "_f32"), input); \ - } else if (input_taichi_type == DataType::f64) { \ + } else if (input_taichi_type == PrimitiveType::f64) { \ llvm_val[stmt] = \ builder->CreateCall(get_runtime_function(#x "_f64"), input); \ - } else if (input_taichi_type == DataType::i32) { \ + } else if (input_taichi_type == PrimitiveType::i32) { \ llvm_val[stmt] = \ builder->CreateCall(get_runtime_function(#x "_i32"), input); \ } else { \ @@ -427,7 +427,7 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { if (is_real(ret_type)) { llvm_val[stmt] = builder->CreateMaxNum(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); - } else if (ret_type == DataType::i32) { + } else if (ret_type == PrimitiveType::i32) { llvm_val[stmt] = create_call("max_i32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else { @@ -436,10 +436,10 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { } } else if (op == BinaryOpType::atan2) { if (arch_is_cpu(current_arch())) { - if (ret_type == DataType::f32) { + if (ret_type == PrimitiveType::f32) { llvm_val[stmt] = create_call( "atan2_f32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else if (ret_type == DataType::f64) { + } else if (ret_type == PrimitiveType::f64) { llvm_val[stmt] = create_call( "atan2_f64", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else { @@ -447,10 +447,10 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { TI_NOT_IMPLEMENTED } } else if (current_arch() == Arch::cuda) { - if (ret_type == DataType::f32) { + if (ret_type == PrimitiveType::f32) { llvm_val[stmt] = create_call( "__nv_atan2f", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else if (ret_type == DataType::f64) { + } else if (ret_type == PrimitiveType::f64) { llvm_val[stmt] = create_call( "__nv_atan2", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else { @@ -462,16 +462,16 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { } } else if (op == BinaryOpType::pow) { if (arch_is_cpu(current_arch())) { - if (ret_type == DataType::f32) { + if (ret_type == PrimitiveType::f32) { llvm_val[stmt] = create_call("pow_f32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else if (ret_type == DataType::f64) { + } else if (ret_type == PrimitiveType::f64) { llvm_val[stmt] = create_call("pow_f64", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else if (ret_type == DataType::i32) { + } else if (ret_type == PrimitiveType::i32) { llvm_val[stmt] = create_call("pow_i32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else if (ret_type == DataType::i64) { + } else if (ret_type == PrimitiveType::i64) { llvm_val[stmt] = create_call("pow_i64", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else { @@ -479,16 +479,16 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { TI_NOT_IMPLEMENTED } } else if (current_arch() == Arch::cuda) { - if (ret_type == DataType::f32) { + if (ret_type == PrimitiveType::f32) { llvm_val[stmt] = create_call( "__nv_powf", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else if (ret_type == DataType::f64) { + } else if (ret_type == PrimitiveType::f64) { llvm_val[stmt] = create_call("__nv_pow", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else if (ret_type == DataType::i32) { + } else if (ret_type == PrimitiveType::i32) { llvm_val[stmt] = create_call("pow_i32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); - } else if (ret_type == DataType::i64) { + } else if (ret_type == PrimitiveType::i64) { llvm_val[stmt] = create_call("pow_i64", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else { @@ -502,7 +502,7 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { if (is_real(ret_type)) { llvm_val[stmt] = builder->CreateMinNum(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); - } else if (ret_type == DataType::i32) { + } else if (ret_type == PrimitiveType::i32) { llvm_val[stmt] = create_call("min_i32", {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); } else { @@ -575,7 +575,7 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { } else { TI_NOT_IMPLEMENTED } - llvm_val[stmt] = builder->CreateSExt(cmp, llvm_type(DataType::i32)); + llvm_val[stmt] = builder->CreateSExt(cmp, llvm_type(PrimitiveType::i32)); } else { TI_P(binary_op_type_name(op)); TI_NOT_IMPLEMENTED @@ -583,13 +583,13 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { } llvm::Type *CodeGenLLVM::llvm_type(DataType dt) { - if (dt == DataType::i32) { + if (dt == PrimitiveType::i32) { return llvm::Type::getInt32Ty(*llvm_context); - } else if (dt == DataType::u1) { + } else if (dt == PrimitiveType::u1) { return llvm::Type::getInt1Ty(*llvm_context); - } else if (dt == DataType::f32) { + } else if (dt == PrimitiveType::f32) { return llvm::Type::getFloatTy(*llvm_context); - } else if (dt == DataType::f64) { + } else if (dt == PrimitiveType::f64) { return llvm::Type::getDoubleTy(*llvm_context); } else { TI_NOT_IMPLEMENTED; @@ -600,7 +600,7 @@ llvm::Type *CodeGenLLVM::llvm_type(DataType dt) { void CodeGenLLVM::visit(TernaryOpStmt *stmt) { TI_ASSERT(stmt->op_type == TernaryOpType::select); llvm_val[stmt] = builder->CreateSelect( - builder->CreateTrunc(llvm_val[stmt->op1], llvm_type(DataType::u1)), + builder->CreateTrunc(llvm_val[stmt->op1], llvm_type(PrimitiveType::u1)), llvm_val[stmt->op2], llvm_val[stmt->op3]); } @@ -637,8 +637,9 @@ llvm::Value *CodeGenLLVM::create_print(std::string tag, args.push_back(builder->CreateGlobalStringPtr( ("[llvm codegen debug] " + tag + " = " + format + "\n").c_str(), "format_string")); - if (dt == DataType::f32) - value = builder->CreateFPExt(value, tlctx->get_data_type(DataType::f64)); + if (dt == PrimitiveType::f32) + value = + builder->CreateFPExt(value, tlctx->get_data_type(PrimitiveType::f64)); args.push_back(value); return builder->CreateCall(runtime_printf, args); } @@ -651,9 +652,9 @@ void CodeGenLLVM::visit(PrintStmt *stmt) { if (std::holds_alternative(content)) { auto arg_stmt = std::get(content); auto value = llvm_val[arg_stmt]; - if (arg_stmt->ret_type.data_type == DataType::f32) - value = - builder->CreateFPExt(value, tlctx->get_data_type(DataType::f64)); + if (arg_stmt->ret_type.data_type == PrimitiveType::f32) + value = builder->CreateFPExt(value, + tlctx->get_data_type(PrimitiveType::f64)); args.push_back(value); formats += data_type_format(arg_stmt->ret_type.data_type); } else { @@ -673,22 +674,22 @@ void CodeGenLLVM::visit(PrintStmt *stmt) { void CodeGenLLVM::visit(ConstStmt *stmt) { TI_ASSERT(stmt->width() == 1); auto val = stmt->val[0]; - if (val.dt == DataType::f32) { + if (val.dt == PrimitiveType::f32) { llvm_val[stmt] = llvm::ConstantFP::get(*llvm_context, llvm::APFloat(val.val_float32())); - } else if (val.dt == DataType::f64) { + } else if (val.dt == PrimitiveType::f64) { llvm_val[stmt] = llvm::ConstantFP::get(*llvm_context, llvm::APFloat(val.val_float64())); - } else if (val.dt == DataType::i32) { + } else if (val.dt == PrimitiveType::i32) { llvm_val[stmt] = llvm::ConstantInt::get( *llvm_context, llvm::APInt(32, (uint64)val.val_int32(), true)); - } else if (val.dt == DataType::u32) { + } else if (val.dt == PrimitiveType::u32) { llvm_val[stmt] = llvm::ConstantInt::get( *llvm_context, llvm::APInt(32, (uint64)val.val_uint32(), false)); - } else if (val.dt == DataType::i64) { + } else if (val.dt == PrimitiveType::i64) { llvm_val[stmt] = llvm::ConstantInt::get( *llvm_context, llvm::APInt(64, (uint64)val.val_int64(), true)); - } else if (val.dt == DataType::u64) { + } else if (val.dt == PrimitiveType::u64) { llvm_val[stmt] = llvm::ConstantInt::get( *llvm_context, llvm::APInt(64, val.val_uint64(), false)); } else { @@ -790,7 +791,7 @@ void CodeGenLLVM::create_naive_range_for(RangeForStmt *for_stmt) { BasicBlock *loop_test = BasicBlock::Create(*llvm_context, "for_loop_test", func); - auto loop_var = create_entry_block_alloca(DataType::i32); + auto loop_var = create_entry_block_alloca(PrimitiveType::i32); loop_vars_llvm[for_stmt].push_back(loop_var); if (!for_stmt->reversed) { @@ -854,7 +855,7 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) { llvm::Type *dest_ty = nullptr; if (stmt->is_ptr) { - dest_ty = PointerType::get(tlctx->get_data_type(DataType::i32), 0); + dest_ty = PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0); llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty); } else { dest_ty = tlctx->get_data_type(stmt->ret_type.data_type); @@ -942,7 +943,7 @@ void CodeGenLLVM::visit(SNodeOpStmt *stmt) { auto snode = stmt->snode; if (stmt->op_type == SNodeOpType::append) { TI_ASSERT(snode->type == SNodeType::dynamic); - TI_ASSERT(stmt->ret_type.data_type == DataType::i32); + TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32); llvm_val[stmt] = call(snode, llvm_val[stmt->ptr], "append", {llvm_val[stmt->val]}); } else if (stmt->op_type == SNodeOpType::length) { @@ -978,11 +979,11 @@ void CodeGenLLVM::visit(AtomicOpStmt *stmt) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::Add, llvm_val[stmt->dest], llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent); - } else if (stmt->val->ret_type.data_type == DataType::f32) { + } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) { old_value = builder->CreateCall(get_runtime_function("atomic_add_f32"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); - } else if (stmt->val->ret_type.data_type == DataType::f64) { + } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) { old_value = builder->CreateCall(get_runtime_function("atomic_add_f64"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); @@ -994,11 +995,11 @@ void CodeGenLLVM::visit(AtomicOpStmt *stmt) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::Min, llvm_val[stmt->dest], llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent); - } else if (stmt->val->ret_type.data_type == DataType::f32) { + } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) { old_value = builder->CreateCall(get_runtime_function("atomic_min_f32"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); - } else if (stmt->val->ret_type.data_type == DataType::f64) { + } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) { old_value = builder->CreateCall(get_runtime_function("atomic_min_f64"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); @@ -1010,11 +1011,11 @@ void CodeGenLLVM::visit(AtomicOpStmt *stmt) { old_value = builder->CreateAtomicRMW( llvm::AtomicRMWInst::BinOp::Max, llvm_val[stmt->dest], llvm_val[stmt->val], llvm::AtomicOrdering::SequentiallyConsistent); - } else if (stmt->val->ret_type.data_type == DataType::f32) { + } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) { old_value = builder->CreateCall(get_runtime_function("atomic_max_f32"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); - } else if (stmt->val->ret_type.data_type == DataType::f64) { + } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) { old_value = builder->CreateCall(get_runtime_function("atomic_max_f64"), {llvm_val[stmt->dest], llvm_val[stmt->val]}); @@ -1296,7 +1297,7 @@ std::tuple CodeGenLLVM::get_range_for_bounds( begin = tlctx->get_constant(stmt->begin_value); } else { auto begin_stmt = Stmt::make( - stmt->begin_offset, VectorType(1, DataType::i32)); + stmt->begin_offset, VectorType(1, PrimitiveType::i32)); begin_stmt->accept(this); begin = builder->CreateLoad(llvm_val[begin_stmt.get()]); } @@ -1304,7 +1305,7 @@ std::tuple CodeGenLLVM::get_range_for_bounds( end = tlctx->get_constant(stmt->end_value); } else { auto end_stmt = Stmt::make( - stmt->end_offset, VectorType(1, DataType::i32)); + stmt->end_offset, VectorType(1, PrimitiveType::i32)); end_stmt->accept(this); end = builder->CreateLoad(llvm_val[end_stmt.get()]); } diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp index ee506326fabda..b8a70ccdfb6dd 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -194,7 +194,8 @@ Expr ptr_if_global(const Expr &var) { Expr Var(const Expr &x) { auto var = Expr(std::make_shared()); current_ast_builder().insert(std::make_unique( - std::static_pointer_cast(var.expr)->id, DataType::unknown)); + std::static_pointer_cast(var.expr)->id, + PrimitiveType::unknown)); var = x; return var; } diff --git a/taichi/ir/frontend.h b/taichi/ir/frontend.h index be75194d99d74..dad9337464d81 100644 --- a/taichi/ir/frontend.h +++ b/taichi/ir/frontend.h @@ -57,7 +57,8 @@ inline void declare_unnamed_var(Expr &a, DataType dt) { inline void declare_var(Expr &a) { current_ast_builder().insert(std::make_unique( - std::static_pointer_cast(a.expr)->id, DataType::unknown)); + std::static_pointer_cast(a.expr)->id, + PrimitiveType::unknown)); } #define Declare(x) auto x = Expr(std::make_shared(#x)); @@ -66,15 +67,15 @@ inline void declare_var(Expr &a) { #define NamedScalar(x, name, dt) \ DeclareNamed(x##_global, #name); \ - auto x = global_new(x##_global, DataType::dt); + auto x = global_new(x##_global, PrimitiveType::dt); #define Global(x, dt) \ Declare(x##_global); \ - auto x = global_new(x##_global, DataType::dt); + auto x = global_new(x##_global, PrimitiveType::dt); -#define AmbientGlobal(x, dt, ambient) \ - Declare(x##_global); \ - auto x = global_new(x##_global, DataType::dt); \ +#define AmbientGlobal(x, dt, ambient) \ + Declare(x##_global); \ + auto x = global_new(x##_global, PrimitiveType::dt); \ set_ambient(x, ambient); inline void set_ambient(Expr expr_, float32 val) { diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 3c757b1dda655..f773e593065f0 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -252,7 +252,7 @@ class UnaryOpExpression : public Expression { UnaryOpExpression(UnaryOpType type, const Expr &operand) : type(type), operand(smart_load(operand)) { - cast_type = DataType::unknown; + cast_type = PrimitiveType::unknown; } bool is_cast() const; diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 0a401d2ed1476..e01c0f40b6356 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -237,7 +237,7 @@ void Stmt::replace_operand_with(Stmt *old_stmt, Stmt *new_stmt) { } std::string Stmt::type_hint() const { - if (ret_type.data_type == DataType::unknown) + if (ret_type.data_type == PrimitiveType::unknown) return ""; else return fmt::format("<{}>{}", ret_type.str(), is_ptr ? "ptr " : " "); diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index bbaf2e886a5e2..47b7588004be5 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -45,7 +45,8 @@ struct VectorType { : _is_pointer(is_pointer), width(width), data_type(data_type) { } - VectorType() : _is_pointer(false), width(1), data_type(DataType::unknown) { + VectorType() + : _is_pointer(false), width(1), data_type(PrimitiveType::unknown) { } bool operator==(const VectorType &o) const { diff --git a/taichi/ir/snode.cpp b/taichi/ir/snode.cpp index 3490f5da3c89c..91ac84608e139 100644 --- a/taichi/ir/snode.cpp +++ b/taichi/ir/snode.cpp @@ -211,7 +211,7 @@ SNode::SNode(int depth, SNodeType t) : depth(depth), type(t) { std::memset(physical_index_position, -1, sizeof(physical_index_position)); parent = nullptr; has_ambient = false; - dt = DataType::gen; + dt = PrimitiveType::gen; _morton = false; reader_kernel = nullptr; diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 6f4783322b87b..18b9460b765d5 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -18,7 +18,7 @@ bool ContinueStmt::as_return() const { UnaryOpStmt::UnaryOpStmt(UnaryOpType op_type, Stmt *operand) : op_type(op_type), operand(operand) { TI_ASSERT(!operand->is()); - cast_type = DataType::unknown; + cast_type = PrimitiveType::unknown; TI_STMT_REG_FIELDS; } @@ -40,7 +40,7 @@ bool UnaryOpStmt::same_operation(UnaryOpStmt *o) const { ExternalPtrStmt::ExternalPtrStmt(const LaneAttribute &base_ptrs, const std::vector &indices) : base_ptrs(base_ptrs), indices(indices) { - DataType dt = DataType::f32; + DataType dt = PrimitiveType::f32; for (int i = 0; i < (int)base_ptrs.size(); i++) { TI_ASSERT(base_ptrs[i] != nullptr); TI_ASSERT(base_ptrs[i]->is()); @@ -91,7 +91,7 @@ SNodeOpStmt::SNodeOpStmt(SNodeOpType op_type, Stmt *val) : op_type(op_type), snode(snode), ptr(ptr), val(val) { width() = 1; - element_type() = DataType::i32; + element_type() = PrimitiveType::i32; TI_STMT_REG_FIELDS; } @@ -105,7 +105,7 @@ SNodeOpStmt::SNodeOpStmt(SNodeOpType op_type, op_type == SNodeOpType::deactivate || op_type == SNodeOpType::activate); width() = 1; - element_type() = DataType::i32; + element_type() = PrimitiveType::i32; TI_STMT_REG_FIELDS; } diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index f8124d6273446..3fb1b4399c2dd 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -60,7 +60,7 @@ class ContinueStmt : public Stmt { // The reason is that, each thread may handle more than one element, // depending on the backend's implementation. // - // For example, CUDA uses gride-stride loops, the snippet below illustrates + // For example, CUDA uses grid-stride loops, the snippet below illustrates // the idea: // // __global__ foo_kernel(...) { @@ -1004,7 +1004,7 @@ class InternalFuncStmt : public Stmt { std::string func_name; InternalFuncStmt(const std::string &func_name) : func_name(func_name) { - this->ret_type = VectorType(1, DataType::i32); + this->ret_type = VectorType(1, PrimitiveType::i32); TI_STMT_REG_FIELDS; } diff --git a/taichi/lang_util.cpp b/taichi/lang_util.cpp index 31147e83c63ef..381b21214a40b 100644 --- a/taichi/lang_util.cpp +++ b/taichi/lang_util.cpp @@ -30,17 +30,21 @@ real get_cpu_frequency() { real default_measurement_time = 1; // Note: these primitive types should never be freed. They are supposed to live -// together with the process. -#define PER_TYPE(x) \ - DataType DataType::x = \ +// together with the process. This is a temporary solution. Later we should +// manage its ownership more systematically +#define PER_TYPE(x) \ + DataType PrimitiveType::x = \ DataType(new PrimitiveType(PrimitiveType::primitive_type::x)); #include "taichi/inc/data_type.inc.h" #undef PER_TYPE +DataType::DataType() : ptr_(PrimitiveType::unknown.ptr_) { +} + DataType PrimitiveType::get(PrimitiveType::primitive_type t) { if (false) { } -#define PER_TYPE(x) else if (t == primitive_type::x) return DataType::x; +#define PER_TYPE(x) else if (t == primitive_type::x) return PrimitiveType::x; #include "taichi/inc/data_type.inc.h" #undef PER_TYPE else { @@ -96,7 +100,7 @@ real measure_cpe(std::function target, } std::string data_type_name(DataType t) { -#define REGISTER_DATA_TYPE(i, j) else if (t == DataType::i) return #j +#define REGISTER_DATA_TYPE(i, j) else if (t == PrimitiveType::i) return #j if (false) { } REGISTER_DATA_TYPE(f16, float16); @@ -119,17 +123,17 @@ std::string data_type_name(DataType t) { } std::string data_type_format(DataType dt) { - if (dt == DataType::i32) { + if (dt == PrimitiveType::i32) { return "%d"; - } else if (dt == DataType::i64) { + } else if (dt == PrimitiveType::i64) { #if defined(TI_PLATFORM_UNIX) return "%lld"; #else return "%I64d"; #endif - } else if (dt == DataType::f32) { + } else if (dt == PrimitiveType::f32) { return "%f"; - } else if (dt == DataType::f64) { + } else if (dt == PrimitiveType::f64) { return "%.12f"; } else { TI_NOT_IMPLEMENTED @@ -138,14 +142,15 @@ std::string data_type_format(DataType dt) { int data_type_size(DataType t) { if (false) { - } else if (t == DataType::f16) + } else if (t == PrimitiveType::f16) return 2; - else if (t == DataType::gen) + else if (t == PrimitiveType::gen) return 0; - else if (t == DataType::unknown) + else if (t == PrimitiveType::unknown) return -1; -#define REGISTER_DATA_TYPE(i, j) else if (t == DataType::i) return sizeof(j) +#define REGISTER_DATA_TYPE(i, j) \ + else if (t == PrimitiveType::i) return sizeof(j) REGISTER_DATA_TYPE(f32, float32); REGISTER_DATA_TYPE(f64, float64); @@ -167,7 +172,7 @@ int data_type_size(DataType t) { std::string data_type_short_name(DataType t) { if (false) { } -#define PER_TYPE(i) else if (t == DataType::i) return #i; +#define PER_TYPE(i) else if (t == PrimitiveType::i) return #i; #include "taichi/inc/data_type.inc.h" #undef PER_TYPE else @@ -399,25 +404,25 @@ DataType promoted_type(DataType a, DataType b) { } std::string TypedConstant::stringify() const { - if (dt == DataType::f32) { + if (dt == PrimitiveType::f32) { return fmt::format("{}", val_f32); - } else if (dt == DataType::i32) { + } else if (dt == PrimitiveType::i32) { return fmt::format("{}", val_i32); - } else if (dt == DataType::i64) { + } else if (dt == PrimitiveType::i64) { return fmt::format("{}", val_i64); - } else if (dt == DataType::f64) { + } else if (dt == PrimitiveType::f64) { return fmt::format("{}", val_f64); - } else if (dt == DataType::i8) { + } else if (dt == PrimitiveType::i8) { return fmt::format("{}", val_i8); - } else if (dt == DataType::i16) { + } else if (dt == PrimitiveType::i16) { return fmt::format("{}", val_i16); - } else if (dt == DataType::u8) { + } else if (dt == PrimitiveType::u8) { return fmt::format("{}", val_u8); - } else if (dt == DataType::u16) { + } else if (dt == PrimitiveType::u16) { return fmt::format("{}", val_u16); - } else if (dt == DataType::u32) { + } else if (dt == PrimitiveType::u32) { return fmt::format("{}", val_u32); - } else if (dt == DataType::u64) { + } else if (dt == PrimitiveType::u64) { return fmt::format("{}", val_u64); } else { TI_P(data_type_name(dt)); @@ -429,25 +434,25 @@ std::string TypedConstant::stringify() const { bool TypedConstant::equal_type_and_value(const TypedConstant &o) const { if (dt != o.dt) return false; - if (dt == DataType::f32) { + if (dt == PrimitiveType::f32) { return val_f32 == o.val_f32; - } else if (dt == DataType::i32) { + } else if (dt == PrimitiveType::i32) { return val_i32 == o.val_i32; - } else if (dt == DataType::i64) { + } else if (dt == PrimitiveType::i64) { return val_i64 == o.val_i64; - } else if (dt == DataType::f64) { + } else if (dt == PrimitiveType::f64) { return val_f64 == o.val_f64; - } else if (dt == DataType::i8) { + } else if (dt == PrimitiveType::i8) { return val_i8 == o.val_i8; - } else if (dt == DataType::i16) { + } else if (dt == PrimitiveType::i16) { return val_i16 == o.val_i16; - } else if (dt == DataType::u8) { + } else if (dt == PrimitiveType::u8) { return val_u8 == o.val_u8; - } else if (dt == DataType::u16) { + } else if (dt == PrimitiveType::u16) { return val_u16 == o.val_u16; - } else if (dt == DataType::u32) { + } else if (dt == PrimitiveType::u32) { return val_u32 == o.val_u32; - } else if (dt == DataType::u64) { + } else if (dt == PrimitiveType::u64) { return val_u64 == o.val_u64; } else { TI_NOT_IMPLEMENTED @@ -507,13 +512,13 @@ uint64 &TypedConstant::val_uint64() { int64 TypedConstant::val_int() const { TI_ASSERT(is_signed(dt)); - if (dt == DataType::i32) { + if (dt == PrimitiveType::i32) { return val_i32; - } else if (dt == DataType::i64) { + } else if (dt == PrimitiveType::i64) { return val_i64; - } else if (dt == DataType::i8) { + } else if (dt == PrimitiveType::i8) { return val_i8; - } else if (dt == DataType::i16) { + } else if (dt == PrimitiveType::i16) { return val_i16; } else { TI_NOT_IMPLEMENTED @@ -522,13 +527,13 @@ int64 TypedConstant::val_int() const { uint64 TypedConstant::val_uint() const { TI_ASSERT(is_unsigned(dt)); - if (dt == DataType::u32) { + if (dt == PrimitiveType::u32) { return val_u32; - } else if (dt == DataType::u64) { + } else if (dt == PrimitiveType::u64) { return val_u64; - } else if (dt == DataType::u8) { + } else if (dt == PrimitiveType::u8) { return val_u8; - } else if (dt == DataType::u16) { + } else if (dt == PrimitiveType::u16) { return val_u16; } else { TI_NOT_IMPLEMENTED @@ -537,9 +542,9 @@ uint64 TypedConstant::val_uint() const { float64 TypedConstant::val_float() const { TI_ASSERT(is_real(dt)); - if (dt == DataType::f32) { + if (dt == PrimitiveType::f32) { return val_f32; - } else if (dt == DataType::f64) { + } else if (dt == PrimitiveType::f64) { return val_f64; } else { TI_NOT_IMPLEMENTED diff --git a/taichi/lang_util.h b/taichi/lang_util.h index b3572447716c5..d51a4bc3cff79 100644 --- a/taichi/lang_util.h +++ b/taichi/lang_util.h @@ -29,11 +29,7 @@ class Type { // A "Type" handle. This should be removed later. class DataType { public: -#define PER_TYPE(x) static DataType x; -#include "taichi/inc/data_type.inc.h" -#undef PER_TYPE - DataType() : ptr_(unknown.ptr_) { - } + DataType(); DataType(const Type *ptr) : ptr_(ptr) { } @@ -69,6 +65,10 @@ class PrimitiveType : public Type { #undef PER_TYPE }; +#define PER_TYPE(x) static DataType x; +#include "taichi/inc/data_type.inc.h" +#undef PER_TYPE + primitive_type type; PrimitiveType(primitive_type type) : type(type) { @@ -82,27 +82,27 @@ class PrimitiveType : public Type { template inline DataType get_data_type() { if (std::is_same()) { - return DataType::f32; + return PrimitiveType::f32; } else if (std::is_same()) { - return DataType::f64; + return PrimitiveType::f64; } else if (std::is_same()) { - return DataType::u1; + return PrimitiveType::u1; } else if (std::is_same()) { - return DataType::i8; + return PrimitiveType::i8; } else if (std::is_same()) { - return DataType::i16; + return PrimitiveType::i16; } else if (std::is_same()) { - return DataType::i32; + return PrimitiveType::i32; } else if (std::is_same()) { - return DataType::i64; + return PrimitiveType::i64; } else if (std::is_same()) { - return DataType::u8; + return PrimitiveType::u8; } else if (std::is_same()) { - return DataType::u16; + return PrimitiveType::u16; } else if (std::is_same()) { - return DataType::u32; + return PrimitiveType::u32; } else if (std::is_same()) { - return DataType::u64; + return PrimitiveType::u64; } else { TI_NOT_IMPLEMENTED; } @@ -143,19 +143,21 @@ inline bool constexpr is_trigonometric(UnaryOpType op) { } inline bool is_real(DataType dt) { - return dt == DataType::f16 || dt == DataType::f32 || dt == DataType::f64; + return dt == PrimitiveType::f16 || dt == PrimitiveType::f32 || + dt == PrimitiveType::f64; } inline bool is_integral(DataType dt) { - return dt == DataType::i8 || dt == DataType::i16 || dt == DataType::i32 || - dt == DataType::i64 || dt == DataType::u8 || dt == DataType::u16 || - dt == DataType::u32 || dt == DataType::u64; + return dt == PrimitiveType::i8 || dt == PrimitiveType::i16 || + dt == PrimitiveType::i32 || dt == PrimitiveType::i64 || + dt == PrimitiveType::u8 || dt == PrimitiveType::u16 || + dt == PrimitiveType::u32 || dt == PrimitiveType::u64; } inline bool is_signed(DataType dt) { TI_ASSERT(is_integral(dt)); - return dt == DataType::i8 || dt == DataType::i16 || dt == DataType::i32 || - dt == DataType::i64; + return dt == PrimitiveType::i8 || dt == PrimitiveType::i16 || + dt == PrimitiveType::i32 || dt == PrimitiveType::i64; } inline bool is_unsigned(DataType dt) { @@ -165,16 +167,16 @@ inline bool is_unsigned(DataType dt) { inline DataType to_unsigned(DataType dt) { TI_ASSERT(is_signed(dt)); - if (dt == DataType::i8) - return DataType::u8; - else if (dt == DataType::i16) - return DataType::u16; - else if (dt == DataType::i32) - return DataType::u32; - else if (dt == DataType::i64) - return DataType::u64; + if (dt == PrimitiveType::i8) + return PrimitiveType::u8; + else if (dt == PrimitiveType::i16) + return PrimitiveType::u16; + else if (dt == PrimitiveType::i32) + return PrimitiveType::u32; + else if (dt == PrimitiveType::i64) + return PrimitiveType::u64; else - return DataType::unknown; + return PrimitiveType::unknown; } inline bool needs_grad(DataType dt) { @@ -182,7 +184,7 @@ inline bool needs_grad(DataType dt) { } // Regular binary ops: -// Operations that take two oprands, and returns a single operand with the +// Operations that take two operands, and returns a single operand with the // same type enum class BinaryOpType : int { @@ -250,46 +252,46 @@ class TypedConstant { }; public: - TypedConstant() : dt(DataType::unknown) { + TypedConstant() : dt(PrimitiveType::unknown) { } TypedConstant(DataType dt) : dt(dt) { value_bits = 0; } - TypedConstant(int32 x) : dt(DataType::i32), val_i32(x) { + TypedConstant(int32 x) : dt(PrimitiveType::i32), val_i32(x) { } - TypedConstant(float32 x) : dt(DataType::f32), val_f32(x) { + TypedConstant(float32 x) : dt(PrimitiveType::f32), val_f32(x) { } - TypedConstant(int64 x) : dt(DataType::i64), val_i64(x) { + TypedConstant(int64 x) : dt(PrimitiveType::i64), val_i64(x) { } - TypedConstant(float64 x) : dt(DataType::f64), val_f64(x) { + TypedConstant(float64 x) : dt(PrimitiveType::f64), val_f64(x) { } template TypedConstant(DataType dt, const T &value) : dt(dt) { - if (dt == DataType::f32) { + if (dt == PrimitiveType::f32) { val_f32 = value; - } else if (dt == DataType::i32) { + } else if (dt == PrimitiveType::i32) { val_i32 = value; - } else if (dt == DataType::i64) { + } else if (dt == PrimitiveType::i64) { val_i64 = value; - } else if (dt == DataType::f64) { + } else if (dt == PrimitiveType::f64) { val_f64 = value; - } else if (dt == DataType::i8) { + } else if (dt == PrimitiveType::i8) { val_i8 = value; - } else if (dt == DataType::i16) { + } else if (dt == PrimitiveType::i16) { val_i16 = value; - } else if (dt == DataType::u8) { + } else if (dt == PrimitiveType::u8) { val_u8 = value; - } else if (dt == DataType::u16) { + } else if (dt == PrimitiveType::u16) { val_u16 = value; - } else if (dt == DataType::u32) { + } else if (dt == PrimitiveType::u32) { val_u32 = value; - } else if (dt == DataType::u64) { + } else if (dt == PrimitiveType::u64) { val_u64 = value; } else { TI_NOT_IMPLEMENTED diff --git a/taichi/llvm/llvm_context.cpp b/taichi/llvm/llvm_context.cpp index f3526493d7de1..ff119be6803be 100644 --- a/taichi/llvm/llvm_context.cpp +++ b/taichi/llvm/llvm_context.cpp @@ -87,25 +87,25 @@ TaichiLLVMContext::TaichiLLVMContext(Arch arch) : arch(arch) { llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { auto ctx = get_this_thread_context(); - if (dt == DataType::i32) { + if (dt == PrimitiveType::i32) { return llvm::Type::getInt32Ty(*ctx); - } else if (dt == DataType::i8) { + } else if (dt == PrimitiveType::i8) { return llvm::Type::getInt8Ty(*ctx); - } else if (dt == DataType::i16) { + } else if (dt == PrimitiveType::i16) { return llvm::Type::getInt16Ty(*ctx); - } else if (dt == DataType::i64) { + } else if (dt == PrimitiveType::i64) { return llvm::Type::getInt64Ty(*ctx); - } else if (dt == DataType::f32) { + } else if (dt == PrimitiveType::f32) { return llvm::Type::getFloatTy(*ctx); - } else if (dt == DataType::f64) { + } else if (dt == PrimitiveType::f64) { return llvm::Type::getDoubleTy(*ctx); - } else if (dt == DataType::u8) { + } else if (dt == PrimitiveType::u8) { return llvm::Type::getInt8Ty(*ctx); - } else if (dt == DataType::u16) { + } else if (dt == PrimitiveType::u16) { return llvm::Type::getInt16Ty(*ctx); - } else if (dt == DataType::u32) { + } else if (dt == PrimitiveType::u32) { return llvm::Type::getInt32Ty(*ctx); - } else if (dt == DataType::u64) { + } else if (dt == PrimitiveType::u64) { return llvm::Type::getInt64Ty(*ctx); } else { TI_INFO(data_type_name(dt)); @@ -560,17 +560,17 @@ void TaichiLLVMContext::set_struct_module( template llvm::Value *TaichiLLVMContext::get_constant(DataType dt, T t) { auto ctx = get_this_thread_context(); - if (dt == DataType::f32) { + if (dt == PrimitiveType::f32) { return llvm::ConstantFP::get(*ctx, llvm::APFloat((float32)t)); - } else if (dt == DataType::f64) { + } else if (dt == PrimitiveType::f64) { return llvm::ConstantFP::get(*ctx, llvm::APFloat((float64)t)); - } else if (dt == DataType::i32) { + } else if (dt == PrimitiveType::i32) { return llvm::ConstantInt::get(*ctx, llvm::APInt(32, t, true)); - } else if (dt == DataType::u32) { + } else if (dt == PrimitiveType::u32) { return llvm::ConstantInt::get(*ctx, llvm::APInt(32, t, false)); - } else if (dt == DataType::i64) { + } else if (dt == PrimitiveType::i64) { return llvm::ConstantInt::get(*ctx, llvm::APInt(64, t, true)); - } else if (dt == DataType::u64) { + } else if (dt == PrimitiveType::u64) { return llvm::ConstantInt::get(*ctx, llvm::APInt(64, t, false)); } else { TI_NOT_IMPLEMENTED diff --git a/taichi/program/compile_config.cpp b/taichi/program/compile_config.cpp index 415c02e7ce553..bfddb8d9bc1f4 100644 --- a/taichi/program/compile_config.cpp +++ b/taichi/program/compile_config.cpp @@ -23,8 +23,8 @@ CompileConfig::CompileConfig() { simplify_before_lower_access = true; lower_access = true; simplify_after_lower_access = true; - default_fp = DataType::f32; - default_ip = DataType::i32; + default_fp = PrimitiveType::f32; + default_ip = PrimitiveType::i32; verbose_kernel_launches = false; kernel_profiler = false; default_cpu_block_dim = 32; diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 7dde274576b1f..ebaeb9a072b71 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -149,25 +149,25 @@ void Kernel::LaunchContextBuilder::set_arg_float(int i, float64 d) { ActionArg("arg_id", i), ActionArg("val", d)}); auto dt = kernel_->args[i].dt; - if (dt == DataType::f32) { + if (dt == PrimitiveType::f32) { ctx_->set_arg(i, (float32)d); - } else if (dt == DataType::f64) { + } else if (dt == PrimitiveType::f64) { ctx_->set_arg(i, (float64)d); - } else if (dt == DataType::i32) { + } else if (dt == PrimitiveType::i32) { ctx_->set_arg(i, (int32)d); - } else if (dt == DataType::i64) { + } else if (dt == PrimitiveType::i64) { ctx_->set_arg(i, (int64)d); - } else if (dt == DataType::i8) { + } else if (dt == PrimitiveType::i8) { ctx_->set_arg(i, (int8)d); - } else if (dt == DataType::i16) { + } else if (dt == PrimitiveType::i16) { ctx_->set_arg(i, (int16)d); - } else if (dt == DataType::u8) { + } else if (dt == PrimitiveType::u8) { ctx_->set_arg(i, (uint8)d); - } else if (dt == DataType::u16) { + } else if (dt == PrimitiveType::u16) { ctx_->set_arg(i, (uint16)d); - } else if (dt == DataType::u32) { + } else if (dt == PrimitiveType::u32) { ctx_->set_arg(i, (uint32)d); - } else if (dt == DataType::u64) { + } else if (dt == PrimitiveType::u64) { ctx_->set_arg(i, (uint64)d); } else { TI_NOT_IMPLEMENTED @@ -184,25 +184,25 @@ void Kernel::LaunchContextBuilder::set_arg_int(int i, int64 d) { ActionArg("arg_id", i), ActionArg("val", d)}); auto dt = kernel_->args[i].dt; - if (dt == DataType::i32) { + if (dt == PrimitiveType::i32) { ctx_->set_arg(i, (int32)d); - } else if (dt == DataType::i64) { + } else if (dt == PrimitiveType::i64) { ctx_->set_arg(i, (int64)d); - } else if (dt == DataType::i8) { + } else if (dt == PrimitiveType::i8) { ctx_->set_arg(i, (int8)d); - } else if (dt == DataType::i16) { + } else if (dt == PrimitiveType::i16) { ctx_->set_arg(i, (int16)d); - } else if (dt == DataType::u8) { + } else if (dt == PrimitiveType::u8) { ctx_->set_arg(i, (uint8)d); - } else if (dt == DataType::u16) { + } else if (dt == PrimitiveType::u16) { ctx_->set_arg(i, (uint16)d); - } else if (dt == DataType::u32) { + } else if (dt == PrimitiveType::u32) { ctx_->set_arg(i, (uint32)d); - } else if (dt == DataType::u64) { + } else if (dt == PrimitiveType::u64) { ctx_->set_arg(i, (uint64)d); - } else if (dt == DataType::f32) { + } else if (dt == PrimitiveType::f32) { ctx_->set_arg(i, (float32)d); - } else if (dt == DataType::f64) { + } else if (dt == PrimitiveType::f64) { ctx_->set_arg(i, (float64)d); } else { TI_NOT_IMPLEMENTED @@ -247,25 +247,25 @@ Context &Kernel::LaunchContextBuilder::get_context() { float64 Kernel::get_ret_float(int i) { auto dt = rets[i].dt; - if (dt == DataType::f32) { + if (dt == PrimitiveType::f32) { return (float64)get_current_program().fetch_result(i); - } else if (dt == DataType::f64) { + } else if (dt == PrimitiveType::f64) { return (float64)get_current_program().fetch_result(i); - } else if (dt == DataType::i32) { + } else if (dt == PrimitiveType::i32) { return (float64)get_current_program().fetch_result(i); - } else if (dt == DataType::i64) { + } else if (dt == PrimitiveType::i64) { return (float64)get_current_program().fetch_result(i); - } else if (dt == DataType::i8) { + } else if (dt == PrimitiveType::i8) { return (float64)get_current_program().fetch_result(i); - } else if (dt == DataType::i16) { + } else if (dt == PrimitiveType::i16) { return (float64)get_current_program().fetch_result(i); - } else if (dt == DataType::u8) { + } else if (dt == PrimitiveType::u8) { return (float64)get_current_program().fetch_result(i); - } else if (dt == DataType::u16) { + } else if (dt == PrimitiveType::u16) { return (float64)get_current_program().fetch_result(i); - } else if (dt == DataType::u32) { + } else if (dt == PrimitiveType::u32) { return (float64)get_current_program().fetch_result(i); - } else if (dt == DataType::u64) { + } else if (dt == PrimitiveType::u64) { return (float64)get_current_program().fetch_result(i); } else { TI_NOT_IMPLEMENTED @@ -274,25 +274,25 @@ float64 Kernel::get_ret_float(int i) { int64 Kernel::get_ret_int(int i) { auto dt = rets[i].dt; - if (dt == DataType::i32) { + if (dt == PrimitiveType::i32) { return (int64)get_current_program().fetch_result(i); - } else if (dt == DataType::i64) { + } else if (dt == PrimitiveType::i64) { return (int64)get_current_program().fetch_result(i); - } else if (dt == DataType::i8) { + } else if (dt == PrimitiveType::i8) { return (int64)get_current_program().fetch_result(i); - } else if (dt == DataType::i16) { + } else if (dt == PrimitiveType::i16) { return (int64)get_current_program().fetch_result(i); - } else if (dt == DataType::u8) { + } else if (dt == PrimitiveType::u8) { return (int64)get_current_program().fetch_result(i); - } else if (dt == DataType::u16) { + } else if (dt == PrimitiveType::u16) { return (int64)get_current_program().fetch_result(i); - } else if (dt == DataType::u32) { + } else if (dt == PrimitiveType::u32) { return (int64)get_current_program().fetch_result(i); - } else if (dt == DataType::u64) { + } else if (dt == PrimitiveType::u64) { return (int64)get_current_program().fetch_result(i); - } else if (dt == DataType::f32) { + } else if (dt == PrimitiveType::f32) { return (int64)get_current_program().fetch_result(i); - } else if (dt == DataType::f64) { + } else if (dt == PrimitiveType::f64) { return (int64)get_current_program().fetch_result(i); } else { TI_NOT_IMPLEMENTED diff --git a/taichi/program/kernel.h b/taichi/program/kernel.h index 42f6c95480142..919ae5a277249 100644 --- a/taichi/program/kernel.h +++ b/taichi/program/kernel.h @@ -28,7 +28,7 @@ class Kernel { bool is_nparray; std::size_t size; - Arg(DataType dt = DataType::unknown, + Arg(DataType dt = PrimitiveType::unknown, bool is_nparray = false, std::size_t size = 0) : dt(dt), is_nparray(is_nparray), size(size) { @@ -38,7 +38,7 @@ class Kernel { struct Ret { DataType dt; - explicit Ret(DataType dt = DataType::unknown) : dt(dt) { + explicit Ret(DataType dt = PrimitiveType::unknown) : dt(dt) { } }; diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index 244543832ef49..f4782fb802a60 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -609,7 +609,7 @@ Kernel &Program::get_snode_reader(SNode *snode) { auto &ker = kernel([snode] { ExprGroup indices; for (int i = 0; i < snode->num_active_indices; i++) { - indices.push_back(Expr::make(i, DataType::i32)); + indices.push_back(Expr::make(i, PrimitiveType::i32)); } auto ret = Stmt::make( load_if_ptr((snode->expr)[indices]), snode->dt); @@ -619,7 +619,7 @@ Kernel &Program::get_snode_reader(SNode *snode) { ker.name = kernel_name; ker.is_accessor = true; for (int i = 0; i < snode->num_active_indices; i++) - ker.insert_arg(DataType::i32, false); + ker.insert_arg(PrimitiveType::i32, false); ker.insert_ret(snode->dt); return ker; } @@ -630,7 +630,7 @@ Kernel &Program::get_snode_writer(SNode *snode) { auto &ker = kernel([&] { ExprGroup indices; for (int i = 0; i < snode->num_active_indices; i++) { - indices.push_back(Expr::make(i, DataType::i32)); + indices.push_back(Expr::make(i, PrimitiveType::i32)); } (snode->expr)[indices] = Expr::make(snode->num_active_indices, snode->dt); @@ -639,7 +639,7 @@ Kernel &Program::get_snode_writer(SNode *snode) { ker.name = kernel_name; ker.is_accessor = true; for (int i = 0; i < snode->num_active_indices; i++) - ker.insert_arg(DataType::i32, false); + ker.insert_arg(PrimitiveType::i32, false); ker.insert_arg(snode->dt, false); return ker; } diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 4a3d6404cd467..270f06fb86986 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -506,7 +506,7 @@ void export_lang(py::module &m) { auto var = Expr(std::make_shared()); current_ast_builder().insert(std::make_unique( std::static_pointer_cast(var.expr)->id, - DataType::unknown)); + PrimitiveType::unknown)); return var; }); m.def("expr_assign", expr_assign); @@ -548,8 +548,9 @@ void export_lang(py::module &m) { unary.export_values(); m.def("make_unary_op_expr", Expr::make); -#define PER_TYPE(x) \ - m.attr(("DataType_" + data_type_name(DataType::x)).c_str()) = DataType::x; +#define PER_TYPE(x) \ + m.attr(("DataType_" + data_type_name(PrimitiveType::x)).c_str()) = \ + PrimitiveType::x; #include "taichi/inc/data_type.inc.h" #undef PER_TYPE diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 9d83f91114a1b..16b43808f6bca 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -106,10 +106,10 @@ class AlgSimp : public BasicStmtVisitor { // a / const -> a * (1 / const) auto reciprocal = Stmt::make_typed( LaneAttribute(rhs->ret_type.data_type)); - if (rhs->ret_type.data_type == DataType::f64) { + if (rhs->ret_type.data_type == PrimitiveType::f64) { reciprocal->val[0].val_float64() = (float64)1.0 / rhs->val[0].val_float64(); - } else if (rhs->ret_type.data_type == DataType::f32) { + } else if (rhs->ret_type.data_type == PrimitiveType::f32) { reciprocal->val[0].val_float32() = (float32)1.0 / rhs->val[0].val_float32(); } else { diff --git a/taichi/transforms/constant_fold.cpp b/taichi/transforms/constant_fold.cpp index 8c44a6dddb108..8c34c860b3b5a 100644 --- a/taichi/transforms/constant_fold.cpp +++ b/taichi/transforms/constant_fold.cpp @@ -76,8 +76,8 @@ class ConstantFold : public BasicStmtVisitor { // ConstStmt of `bad` types like `i8` is not supported by LLVM. // Discussion: // https://github.com/taichi-dev/taichi/pull/839#issuecomment-625902727 - if (dt == DataType::i32 || dt == DataType::f32 || dt == DataType::i64 || - dt == DataType::f64) + if (dt == PrimitiveType::i32 || dt == PrimitiveType::f32 || + dt == PrimitiveType::i64 || dt == PrimitiveType::f64) return true; else return false; diff --git a/taichi/transforms/demote_dense_struct_fors.cpp b/taichi/transforms/demote_dense_struct_fors.cpp index 265080331a465..7b4c358624c75 100644 --- a/taichi/transforms/demote_dense_struct_fors.cpp +++ b/taichi/transforms/demote_dense_struct_fors.cpp @@ -87,7 +87,7 @@ void convert_to_range_for(OffloadedStmt *offloaded) { } for (int i = 0; i < num_loop_vars; i++) { - auto alloca = body_header.push_back(DataType::i32); + auto alloca = body_header.push_back(PrimitiveType::i32); body_header.push_back(alloca, new_loop_vars[i]); irpass::replace_statements_with( body.get(), diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index bd6f4baeb8894..32474778e4e61 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -81,8 +81,8 @@ class LowerAST : public IRVisitor { auto new_if = std::make_unique(stmt->condition->stmt); - new_if->true_mask = fctx.push_back(DataType::i32); - new_if->false_mask = fctx.push_back(DataType::i32); + new_if->true_mask = fctx.push_back(PrimitiveType::i32); + new_if->false_mask = fctx.push_back(PrimitiveType::i32); fctx.push_back(new_if->true_mask, stmt->condition->stmt); auto lnot_stmt_ptr = fctx.push_back(UnaryOpType::logic_not, @@ -154,7 +154,7 @@ class LowerAST : public IRVisitor { auto cond_stmt = fctx.back_stmt(); auto &&new_while = std::make_unique(std::move(stmt->body)); - auto mask = std::make_unique(DataType::i32); + auto mask = std::make_unique(PrimitiveType::i32); new_while->mask = mask.get(); auto &stmts = new_while->body; stmts->insert(std::move(fctx.stmts), /*location=*/0); @@ -162,7 +162,7 @@ class LowerAST : public IRVisitor { stmts->insert( std::make_unique(new_while->mask, cond_stmt), fctx.stmts.size()); - stmt->insert_before_me(std::make_unique(DataType::i32)); + stmt->insert_before_me(std::make_unique(PrimitiveType::i32)); auto &&const_stmt = std::make_unique(TypedConstant((int32)0xFFFFFFFF)); auto const_stmt_ptr = const_stmt.get(); @@ -216,7 +216,7 @@ class LowerAST : public IRVisitor { } else { // transform into a structure as // i = begin; while (1) { if (i >= end) break; original body; i += 1; } - fctx.push_back(DataType::i32); + fctx.push_back(PrimitiveType::i32); auto loop_var = fctx.back_stmt(); stmt->parent->local_var_to_stmt[stmt->loop_var_id[0]] = loop_var; fctx.push_back(loop_var, begin->stmt); @@ -229,7 +229,7 @@ class LowerAST : public IRVisitor { BinaryOpType::cmp_lt, loop_var_load_stmt, end->stmt); auto &&new_while = std::make_unique(std::move(stmt->body)); - auto mask = std::make_unique(DataType::i32); + auto mask = std::make_unique(PrimitiveType::i32); new_while->mask = mask.get(); auto &stmts = new_while->body; for (int i = 0; i < (int)load_and_compare.size(); i++) { @@ -251,7 +251,8 @@ class LowerAST : public IRVisitor { std::make_unique(new_while->mask, cond_stmt), load_and_compare.size()); - stmt->insert_before_me(std::make_unique(DataType::i32)); + stmt->insert_before_me( + std::make_unique(PrimitiveType::i32)); auto &&const_stmt = std::make_unique(TypedConstant((int32)0xFFFFFFFF)); auto const_stmt_ptr = const_stmt.get(); @@ -323,7 +324,7 @@ class LowerAST : public IRVisitor { auto fctx = make_flatten_ctx(); expr->flatten(&fctx); const auto dt = stmt->element_type(); - TI_ASSERT(dt != DataType::unknown); + TI_ASSERT(dt != PrimitiveType::unknown); fctx.push_back(fctx.back_stmt(), dt); stmt->parent->replace_with(stmt, std::move(fctx.stmts)); throw IRModified(); diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 9c45f9e6e442e..61028ba4c7ddc 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -183,7 +183,7 @@ class BasicBlockSimplify : public IRVisitor { if (k == num_loop_vars - 1) { auto load = stmt->insert_before_me( Stmt::make(current_struct_for, k)); - load->ret_type.data_type = DataType::i32; + load->ret_type.data_type = PrimitiveType::i32; stmt->input = load; int64 bound = 1LL << stmt->bit_end; auto offset = (((int64)diff.low % bound + bound) % bound) & @@ -204,7 +204,7 @@ class BasicBlockSimplify : public IRVisitor { if (offset != 0) { auto offset_const = stmt->insert_before_me( Stmt::make(LaneAttribute( - TypedConstant(DataType::i32, offset)))); + TypedConstant(PrimitiveType::i32, offset)))); auto sum = stmt->insert_before_me(Stmt::make( BinaryOpType::add, load, offset_const)); stmt->input = sum; @@ -214,12 +214,12 @@ class BasicBlockSimplify : public IRVisitor { // insert constant auto load = stmt->insert_before_me( Stmt::make(current_struct_for, k)); - load->ret_type.data_type = DataType::i32; + load->ret_type.data_type = PrimitiveType::i32; auto constant = stmt->insert_before_me( Stmt::make(TypedConstant(diff.low))); auto add = stmt->insert_before_me( Stmt::make(BinaryOpType::add, load, constant)); - add->ret_type.data_type = DataType::i32; + add->ret_type.data_type = PrimitiveType::i32; stmt->input = add; } stmt->simplified = true; @@ -294,8 +294,8 @@ class BasicBlockSimplify : public IRVisitor { // compute offset... for (int i = 0; i < (int)snode->ch.size(); i++) { TI_ASSERT(snode->ch[i]->type == SNodeType::place); - TI_ASSERT(snode->ch[i]->dt == DataType::i32 || - snode->ch[i]->dt == DataType::f32); + TI_ASSERT(snode->ch[i]->dt == PrimitiveType::i32 || + snode->ch[i]->dt == PrimitiveType::f32); } auto offset_stmt = stmt->insert_after_me(Stmt::make( diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index b749c1af12072..fa07cbe40d14c 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -38,9 +38,9 @@ class TypeCheck : public IRVisitor { } void visit(IfStmt *if_stmt) { - // TODO: use DataType::u1 when it's supported + // TODO: use PrimitiveType::u1 when it's supported TI_ASSERT_INFO( - if_stmt->cond->ret_type.data_type == DataType::i32, + if_stmt->cond->ret_type.data_type == PrimitiveType::i32, "`if` conditions must be of type int32, consider using `if x != 0:` " "instead of `if x:` for float values."); if (if_stmt->true_statements) @@ -69,7 +69,7 @@ class TypeCheck : public IRVisitor { stmt->val = insert_type_cast_before(stmt, stmt->val, stmt->dest->ret_type.data_type); } - if (stmt->element_type() == DataType::unknown) { + if (stmt->element_type() == PrimitiveType::unknown) { stmt->ret_type = stmt->dest->ret_type; } stmt->ret_type.set_is_pointer(false); @@ -82,7 +82,7 @@ class TypeCheck : public IRVisitor { } void visit(LocalStoreStmt *stmt) { - if (stmt->ptr->ret_type.data_type == DataType::unknown) { + if (stmt->ptr->ret_type.data_type == PrimitiveType::unknown) { // Infer data type for alloca stmt->ptr->ret_type = stmt->data->ret_type; } @@ -110,11 +110,11 @@ class TypeCheck : public IRVisitor { } void visit(SNodeOpStmt *stmt) { - stmt->ret_type = VectorType(1, DataType::i32); + stmt->ret_type = VectorType(1, PrimitiveType::i32); } void visit(ExternalTensorShapeAlongAxisStmt *stmt) { - stmt->ret_type = VectorType(1, DataType::i32); + stmt->ret_type = VectorType(1, PrimitiveType::i32); } void visit(GlobalPtrStmt *stmt) { @@ -138,7 +138,7 @@ class TypeCheck : public IRVisitor { "[{}] Field index {} not integral, casting into int32 implicitly", stmt->name(), i); stmt->indices[i] = - insert_type_cast_before(stmt, stmt->indices[i], DataType::i32); + insert_type_cast_before(stmt, stmt->indices[i], PrimitiveType::i32); } TI_ASSERT(stmt->indices[i]->ret_type.width == stmt->snodes.size()); } @@ -161,8 +161,8 @@ class TypeCheck : public IRVisitor { } void visit(RangeForStmt *stmt) { - mark_as_if_const(stmt->begin, VectorType(1, DataType::i32)); - mark_as_if_const(stmt->end, VectorType(1, DataType::i32)); + mark_as_if_const(stmt->begin, VectorType(1, PrimitiveType::i32)); + mark_as_if_const(stmt->end, VectorType(1, PrimitiveType::i32)); stmt->body->accept(this); } @@ -239,8 +239,8 @@ class TypeCheck : public IRVisitor { TI_WARN("Compilation stopped due to type mismatch."); throw std::runtime_error("Binary operator type mismatch"); }; - if (stmt->lhs->ret_type.data_type == DataType::unknown && - stmt->rhs->ret_type.data_type == DataType::unknown) + if (stmt->lhs->ret_type.data_type == PrimitiveType::unknown && + stmt->rhs->ret_type.data_type == PrimitiveType::unknown) error(); // lower truediv into div @@ -273,8 +273,10 @@ class TypeCheck : public IRVisitor { bool matching = true; matching = matching && (stmt->lhs->ret_type.width == stmt->rhs->ret_type.width); - matching = matching && (stmt->lhs->ret_type.data_type != DataType::unknown); - matching = matching && (stmt->rhs->ret_type.data_type != DataType::unknown); + matching = + matching && (stmt->lhs->ret_type.data_type != PrimitiveType::unknown); + matching = + matching && (stmt->rhs->ret_type.data_type != PrimitiveType::unknown); matching = matching && (stmt->lhs->ret_type == stmt->rhs->ret_type); if (!matching) { error(); @@ -285,7 +287,8 @@ class TypeCheck : public IRVisitor { } } if (is_comparison(stmt->op_type)) { - stmt->ret_type = VectorType(stmt->lhs->ret_type.width, DataType::i32); + stmt->ret_type = + VectorType(stmt->lhs->ret_type.width, PrimitiveType::i32); } else { stmt->ret_type = stmt->lhs->ret_type; } @@ -295,7 +298,7 @@ class TypeCheck : public IRVisitor { if (stmt->op_type == TernaryOpType::select) { auto ret_type = promoted_type(stmt->op2->ret_type.data_type, stmt->op3->ret_type.data_type); - TI_ASSERT(stmt->op1->ret_type.data_type == DataType::i32) + TI_ASSERT(stmt->op1->ret_type.data_type == PrimitiveType::i32) TI_ASSERT(stmt->op1->ret_type.width == stmt->op2->ret_type.width); TI_ASSERT(stmt->op2->ret_type.width == stmt->op3->ret_type.width); if (ret_type != stmt->op2->ret_type.data_type) { @@ -327,7 +330,7 @@ class TypeCheck : public IRVisitor { // TODO: Maybe have a type_inference() pass, which takes in the args/rets // defined by the kernel. After that, type_check() pass will purely do // verification, without modifying any types. - TI_ASSERT(rt.data_type != DataType::unknown); + TI_ASSERT(rt.data_type != PrimitiveType::unknown); TI_ASSERT(rt.width == 1); } @@ -345,27 +348,27 @@ class TypeCheck : public IRVisitor { } void visit(LoopIndexStmt *stmt) { - stmt->ret_type = VectorType(1, DataType::i32); + stmt->ret_type = VectorType(1, PrimitiveType::i32); } void visit(LoopLinearIndexStmt *stmt) { - stmt->ret_type = VectorType(1, DataType::i32); + stmt->ret_type = VectorType(1, PrimitiveType::i32); } void visit(BlockCornerIndexStmt *stmt) { - stmt->ret_type = VectorType(1, DataType::i32); + stmt->ret_type = VectorType(1, PrimitiveType::i32); } void visit(BlockDimStmt *stmt) { - stmt->ret_type = VectorType(1, DataType::i32); + stmt->ret_type = VectorType(1, PrimitiveType::i32); } void visit(GetRootStmt *stmt) { - stmt->ret_type = VectorType(1, DataType::gen, true); + stmt->ret_type = VectorType(1, PrimitiveType::gen, true); } void visit(SNodeLookupStmt *stmt) { - stmt->ret_type = VectorType(1, DataType::gen, true); + stmt->ret_type = VectorType(1, PrimitiveType::gen, true); } void visit(GetChStmt *stmt) { @@ -382,11 +385,11 @@ class TypeCheck : public IRVisitor { } void visit(LinearizeStmt *stmt) { - stmt->ret_type.data_type = DataType::i32; + stmt->ret_type.data_type = PrimitiveType::i32; } void visit(IntegerOffsetStmt *stmt) { - stmt->ret_type.data_type = DataType::i32; + stmt->ret_type.data_type = PrimitiveType::i32; } void visit(StackAllocaStmt *stmt) { diff --git a/tests/cpp/test_alg_simp.cpp b/tests/cpp/test_alg_simp.cpp index 049e5a8ac0639..7ff940e0f16db 100644 --- a/tests/cpp/test_alg_simp.cpp +++ b/tests/cpp/test_alg_simp.cpp @@ -16,14 +16,14 @@ TI_TEST("alg_simp") { std::make_unique(get_current_program(), func, "fake_kernel"); block->kernel = kernel.get(); - auto global_load_addr = - block->push_back(0, VectorType(1, DataType::i32)); + auto global_load_addr = block->push_back( + 0, VectorType(1, PrimitiveType::i32)); auto global_load = block->push_back(global_load_addr); auto zero = block->push_back(TypedConstant(0)); auto add = block->push_back(BinaryOpType::add, global_load, zero); - auto global_store_addr = - block->push_back(4, VectorType(1, DataType::i32)); + auto global_store_addr = block->push_back( + 4, VectorType(1, PrimitiveType::i32)); auto global_store = block->push_back(global_store_addr, add); @@ -51,8 +51,8 @@ TI_TEST("alg_simp") { std::make_unique(get_current_program(), func, "fake_kernel"); block->kernel = kernel.get(); - auto global_load_addr = - block->push_back(0, VectorType(1, DataType::f32)); + auto global_load_addr = block->push_back( + 0, VectorType(1, PrimitiveType::f32)); auto global_load = block->push_back(global_load_addr); auto one = block->push_back(TypedConstant(1.0f)); auto mul1 = @@ -61,8 +61,8 @@ TI_TEST("alg_simp") { auto zero = block->push_back(TypedConstant(0.0f)); auto div = block->push_back(BinaryOpType::div, zero, one); auto sub = block->push_back(BinaryOpType::sub, mul2, div); - auto global_store_addr = - block->push_back(4, VectorType(1, DataType::f32)); + auto global_store_addr = block->push_back( + 4, VectorType(1, PrimitiveType::f32)); auto global_store = block->push_back(global_store_addr, sub); @@ -89,16 +89,16 @@ TI_TEST("alg_simp") { std::make_unique(get_current_program(), func, "fake_kernel"); block->kernel = kernel.get(); - auto global_load_addr = - block->push_back(0, VectorType(1, DataType::i32)); + auto global_load_addr = block->push_back( + 0, VectorType(1, PrimitiveType::i32)); auto global_load = block->push_back(global_load_addr); auto zero = block->push_back(TypedConstant(0)); auto mul = block->push_back(BinaryOpType::mul, global_load, zero); auto one = block->push_back(TypedConstant(1)); auto add = block->push_back(BinaryOpType::add, mul, one); - auto global_store_addr = - block->push_back(4, VectorType(1, DataType::i32)); + auto global_store_addr = block->push_back( + 4, VectorType(1, PrimitiveType::i32)); auto global_store = block->push_back(global_store_addr, add); @@ -117,15 +117,15 @@ TI_TEST("alg_simp") { block = std::make_unique(); block->kernel = kernel.get(); - global_load_addr = - block->push_back(8, VectorType(1, DataType::f32)); + global_load_addr = block->push_back( + 8, VectorType(1, PrimitiveType::f32)); global_load = block->push_back(global_load_addr); zero = block->push_back(TypedConstant(0)); mul = block->push_back(BinaryOpType::mul, global_load, zero); one = block->push_back(TypedConstant(1)); add = block->push_back(BinaryOpType::add, mul, one); - global_store_addr = - block->push_back(12, VectorType(1, DataType::f32)); + global_store_addr = block->push_back( + 12, VectorType(1, PrimitiveType::f32)); global_store = block->push_back(global_store_addr, add); irpass::type_check(block.get()); // insert 2 casts @@ -151,14 +151,14 @@ TI_TEST("alg_simp") { auto block = std::make_unique(); - auto global_load_addr = - block->push_back(0, VectorType(1, DataType::i32)); + auto global_load_addr = block->push_back( + 0, VectorType(1, PrimitiveType::i32)); auto global_load = block->push_back(global_load_addr); auto minus_one = block->push_back(TypedConstant(-1)); auto and_result = block->push_back(BinaryOpType::bit_and, minus_one, global_load); - auto global_store_addr = - block->push_back(4, VectorType(1, DataType::i32)); + auto global_store_addr = block->push_back( + 4, VectorType(1, PrimitiveType::i32)); auto global_store = block->push_back(global_store_addr, and_result); diff --git a/tests/cpp/test_same_statements.cpp b/tests/cpp/test_same_statements.cpp index 7d1b984463e3d..a802f784c54a0 100644 --- a/tests/cpp/test_same_statements.cpp +++ b/tests/cpp/test_same_statements.cpp @@ -9,11 +9,11 @@ TI_TEST("same_statements") { SECTION("test_same_block") { auto block = std::make_unique(); - auto global_load_addr = - block->push_back(0, VectorType(1, DataType::i32)); + auto global_load_addr = block->push_back( + 0, VectorType(1, PrimitiveType::i32)); auto global_load = block->push_back(global_load_addr); - auto global_store_addr = - block->push_back(4, VectorType(1, DataType::i32)); + auto global_store_addr = block->push_back( + 4, VectorType(1, PrimitiveType::i32)); auto one = block->push_back(TypedConstant(1)); auto if_stmt = block->push_back(one)->as();