From ec240d149be5bd685a2f62bb58be1198d11922b6 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Mon, 12 Oct 2020 01:42:27 -0400 Subject: [PATCH 1/2] [refactor] VectorType -> LegacyVectorType --- taichi/backends/metal/codegen_metal.cpp | 2 +- taichi/codegen/codegen_llvm.cpp | 4 ++-- taichi/ir/frontend_ir.h | 4 ++-- taichi/ir/ir.cpp | 6 ++--- taichi/ir/ir.h | 12 +++++----- taichi/ir/statements.h | 16 ++++++------- taichi/transforms/make_block_local.cpp | 6 ++--- taichi/transforms/make_thread_local.cpp | 6 ++--- taichi/transforms/offload.cpp | 4 ++-- taichi/transforms/type_check.cpp | 30 ++++++++++++------------- tests/cpp/test_alg_simp.cpp | 20 ++++++++--------- tests/cpp/test_same_statements.cpp | 4 ++-- 12 files changed, 57 insertions(+), 57 deletions(-) diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index aa49f1a1e05e1..cf9f5570c9480 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -1039,7 +1039,7 @@ class KernelCodegen : public IRVisitor { std::string inject_load_global_tmp(int offset, DataType dt = PrimitiveType::i32) { - const auto vt = VectorType(/*width=*/1, dt); + const auto vt = LegacyVectorType(/*width=*/1, dt); auto gtmp = Stmt::make(offset, vt); gtmp->accept(this); auto gload = Stmt::make(gtmp.get()); diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index dc2337ecbd779..346a6b5716e17 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1297,7 +1297,7 @@ std::tuple CodeGenLLVM::get_range_for_bounds( begin = tlctx->get_constant(stmt->begin_value); } else { auto begin_stmt = Stmt::make( - stmt->begin_offset, VectorType(1, PrimitiveType::i32)); + stmt->begin_offset, LegacyVectorType(1, PrimitiveType::i32)); begin_stmt->accept(this); begin = builder->CreateLoad(llvm_val[begin_stmt.get()]); } @@ -1305,7 +1305,7 @@ std::tuple CodeGenLLVM::get_range_for_bounds( end = tlctx->get_constant(stmt->end_value); } else { auto end_stmt = Stmt::make( - stmt->end_offset, VectorType(1, PrimitiveType::i32)); + stmt->end_offset, LegacyVectorType(1, PrimitiveType::i32)); end_stmt->accept(this); end = builder->CreateLoad(llvm_val[end_stmt.get()]); } diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index f773e593065f0..5ea6738b17c95 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -16,7 +16,7 @@ class FrontendAllocaStmt : public Stmt { Identifier ident; FrontendAllocaStmt(const Identifier &lhs, DataType type) : ident(lhs) { - ret_type = VectorType(1, type); + ret_type = LegacyVectorType(1, type); } TI_DEFINE_ACCEPT @@ -203,7 +203,7 @@ class FrontendKernelReturnStmt : public Stmt { Expr value; FrontendKernelReturnStmt(const Expr &value, DataType dt) : value(value) { - ret_type = VectorType(1, dt); + ret_type = LegacyVectorType(1, dt); } bool is_container_statement() const override { diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index e01c0f40b6356..cb27783d0a546 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -21,7 +21,7 @@ IRBuilder ¤t_ast_builder() { return context->builder(); } -std::string VectorType::pointer_suffix() const { +std::string LegacyVectorType::pointer_suffix() const { if (is_pointer()) { return "*"; } else { @@ -29,11 +29,11 @@ std::string VectorType::pointer_suffix() const { } } -std::string VectorType::element_type_name() const { +std::string LegacyVectorType::element_type_name() const { return fmt::format("{}{}", data_type_short_name(data_type), pointer_suffix()); } -std::string VectorType::str() const { +std::string LegacyVectorType::str() const { auto ename = element_type_name(); return fmt::format("{:4}x{}", ename, width); } diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 47b7588004be5..12c4d8c2d04a9 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -33,7 +33,7 @@ using ScratchPadOptions = std::vector>; IRBuilder ¤t_ast_builder(); -struct VectorType { +struct LegacyVectorType { private: bool _is_pointer; @@ -41,19 +41,19 @@ struct VectorType { int width; DataType data_type; - VectorType(int width, DataType data_type, bool is_pointer = false) + LegacyVectorType(int width, DataType data_type, bool is_pointer = false) : _is_pointer(is_pointer), width(width), data_type(data_type) { } - VectorType() + LegacyVectorType() : _is_pointer(false), width(1), data_type(PrimitiveType::unknown) { } - bool operator==(const VectorType &o) const { + bool operator==(const LegacyVectorType &o) const { return width == o.width && data_type == o.data_type; } - bool operator!=(const VectorType &o) const { + bool operator!=(const LegacyVectorType &o) const { return !(*this == o); } @@ -531,7 +531,7 @@ class Stmt : public IRNode { bool fields_registered; std::string tb; bool is_ptr; - VectorType ret_type; + LegacyVectorType ret_type; Stmt(); Stmt(const Stmt &stmt); diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 3fb1b4399c2dd..0b5309b917a64 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -8,12 +8,12 @@ TLANG_NAMESPACE_BEGIN class AllocaStmt : public Stmt { public: AllocaStmt(DataType type) { - ret_type = VectorType(1, type); + ret_type = LegacyVectorType(1, type); TI_STMT_REG_FIELDS; } AllocaStmt(int width, DataType type) { - ret_type = VectorType(width, type); + ret_type = LegacyVectorType(width, type); TI_STMT_REG_FIELDS; } @@ -103,7 +103,7 @@ class ArgLoadStmt : public Stmt { int arg_id; ArgLoadStmt(int arg_id, DataType dt, bool is_ptr = false) : arg_id(arg_id) { - this->ret_type = VectorType(1, dt); + this->ret_type = LegacyVectorType(1, dt); this->is_ptr = is_ptr; TI_STMT_REG_FIELDS; } @@ -613,7 +613,7 @@ class KernelReturnStmt : public Stmt { Stmt *value; KernelReturnStmt(Stmt *value, DataType dt) : value(value) { - this->ret_type = VectorType(1, dt); + this->ret_type = LegacyVectorType(1, dt); TI_STMT_REG_FIELDS; } @@ -938,7 +938,7 @@ class GlobalTemporaryStmt : public Stmt { public: std::size_t offset; - GlobalTemporaryStmt(std::size_t offset, VectorType ret_type) + GlobalTemporaryStmt(std::size_t offset, LegacyVectorType ret_type) : offset(offset) { this->ret_type = ret_type; TI_STMT_REG_FIELDS; @@ -956,7 +956,7 @@ class ThreadLocalPtrStmt : public Stmt { public: std::size_t offset; - ThreadLocalPtrStmt(std::size_t offset, VectorType ret_type) : offset(offset) { + ThreadLocalPtrStmt(std::size_t offset, LegacyVectorType ret_type) : offset(offset) { this->ret_type = ret_type; TI_STMT_REG_FIELDS; } @@ -973,7 +973,7 @@ class BlockLocalPtrStmt : public Stmt { public: Stmt *offset; - BlockLocalPtrStmt(Stmt *offset, VectorType ret_type) : offset(offset) { + BlockLocalPtrStmt(Stmt *offset, LegacyVectorType ret_type) : offset(offset) { this->ret_type = ret_type; TI_STMT_REG_FIELDS; } @@ -1004,7 +1004,7 @@ class InternalFuncStmt : public Stmt { std::string func_name; InternalFuncStmt(const std::string &func_name) : func_name(func_name) { - this->ret_type = VectorType(1, PrimitiveType::i32); + this->ret_type = LegacyVectorType(1, PrimitiveType::i32); TI_STMT_REG_FIELDS; } diff --git a/taichi/transforms/make_block_local.cpp b/taichi/transforms/make_block_local.cpp index 7e87690c25170..759de772b027c 100644 --- a/taichi/transforms/make_block_local.cpp +++ b/taichi/transforms/make_block_local.cpp @@ -179,7 +179,7 @@ void make_block_local_offload(OffloadedStmt *offload) { TypedConstant(data_type, 0)); } auto bls_ptr = element_block->push_back( - bls_element_offset_bytes, VectorType(1, data_type)); + bls_element_offset_bytes, LegacyVectorType(1, data_type)); element_block->push_back(bls_ptr, value); }); } @@ -268,7 +268,7 @@ void make_block_local_offload(OffloadedStmt *offload) { bls.push_back(TypedConstant((int32)bls_offset))); bls.push_back(bls_element_offset, - VectorType(1, data_type)); + LegacyVectorType(1, data_type)); global_ptr->replace_with(std::move(bls)); } } @@ -282,7 +282,7 @@ void make_block_local_offload(OffloadedStmt *offload) { Stmt *bls_element_offset_bytes) { // Store/accumulate from BLS to global auto bls_ptr = element_block->push_back( - bls_element_offset_bytes, VectorType(1, data_type)); + bls_element_offset_bytes, LegacyVectorType(1, data_type)); auto bls_val = element_block->push_back(bls_ptr); auto global_pointer = diff --git a/taichi/transforms/make_thread_local.cpp b/taichi/transforms/make_thread_local.cpp index ec6c0b72b662d..cb2fe181edb5b 100644 --- a/taichi/transforms/make_thread_local.cpp +++ b/taichi/transforms/make_thread_local.cpp @@ -127,7 +127,7 @@ void make_thread_local_offload(OffloadedStmt *offload) { tls_offset += (dtype_size - tls_offset % dtype_size) % dtype_size; auto tls_ptr = offload->tls_prologue->push_back( - tls_offset, VectorType(1, data_type)); + tls_offset, LegacyVectorType(1, data_type)); auto zero = offload->tls_prologue->insert( std::make_unique(TypedConstant(data_type, 0)), -1); @@ -140,7 +140,7 @@ void make_thread_local_offload(OffloadedStmt *offload) { // Make loop body accumulate to TLS ptr instead of global ptr { auto tls_ptr = offload->body->insert( - Stmt::make(tls_offset, VectorType(1, data_type)), + Stmt::make(tls_offset, LegacyVectorType(1, data_type)), 0); dest->replace_with(tls_ptr); } @@ -153,7 +153,7 @@ void make_thread_local_offload(OffloadedStmt *offload) { offload->tls_epilogue->parent_stmt = offload; } auto tls_ptr = offload->tls_epilogue->push_back( - tls_offset, VectorType(1, data_type)); + tls_offset, LegacyVectorType(1, data_type)); // TODO: do not use global load from TLS. auto tls_load = offload->tls_epilogue->push_back(tls_ptr); auto global_ptr = offload->tls_epilogue->insert( diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index 2c809b70a461c..04368916d5ea9 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -257,7 +257,7 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor { global_offset = 0; } - std::size_t allocate_global(VectorType type) { + std::size_t allocate_global(LegacyVectorType type) { TI_ASSERT(type.width == 1); auto ret = global_offset; global_offset += data_type_size(type.data_type); @@ -563,7 +563,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { StmtToOffsetMap local_to_global_offset; std::unordered_map stmt_to_offloaded; OffloadedRanges *const offloaded_ranges_; - std::unordered_map local_to_global_vector_type; + std::unordered_map local_to_global_vector_type; }; void insert_gc(IRNode *root) { diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index fa07cbe40d14c..44fa1df9bb4dd 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -24,7 +24,7 @@ class TypeCheck : public IRVisitor { allow_undefined_visitor = true; } - static void mark_as_if_const(Stmt *stmt, VectorType t) { + static void mark_as_if_const(Stmt *stmt, LegacyVectorType t) { if (stmt->is()) { stmt->ret_type = t; } @@ -110,11 +110,11 @@ class TypeCheck : public IRVisitor { } void visit(SNodeOpStmt *stmt) { - stmt->ret_type = VectorType(1, PrimitiveType::i32); + stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32); } void visit(ExternalTensorShapeAlongAxisStmt *stmt) { - stmt->ret_type = VectorType(1, PrimitiveType::i32); + stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32); } void visit(GlobalPtrStmt *stmt) { @@ -161,8 +161,8 @@ class TypeCheck : public IRVisitor { } void visit(RangeForStmt *stmt) { - mark_as_if_const(stmt->begin, VectorType(1, PrimitiveType::i32)); - mark_as_if_const(stmt->end, VectorType(1, PrimitiveType::i32)); + mark_as_if_const(stmt->begin, LegacyVectorType(1, PrimitiveType::i32)); + mark_as_if_const(stmt->end, LegacyVectorType(1, PrimitiveType::i32)); stmt->body->accept(this); } @@ -288,7 +288,7 @@ class TypeCheck : public IRVisitor { } if (is_comparison(stmt->op_type)) { stmt->ret_type = - VectorType(stmt->lhs->ret_type.width, PrimitiveType::i32); + LegacyVectorType(stmt->lhs->ret_type.width, PrimitiveType::i32); } else { stmt->ret_type = stmt->lhs->ret_type; } @@ -309,7 +309,7 @@ class TypeCheck : public IRVisitor { auto cast_stmt = insert_type_cast_before(stmt, stmt->op3, ret_type); stmt->op3 = cast_stmt; } - stmt->ret_type = VectorType(stmt->op1->width(), ret_type); + stmt->ret_type = LegacyVectorType(stmt->op1->width(), ret_type); } else { TI_NOT_IMPLEMENTED } @@ -343,36 +343,36 @@ class TypeCheck : public IRVisitor { void visit(ExternalPtrStmt *stmt) { stmt->ret_type.set_is_pointer(true); - stmt->ret_type = VectorType(stmt->base_ptrs.size(), + stmt->ret_type = LegacyVectorType(stmt->base_ptrs.size(), stmt->base_ptrs[0]->ret_type.data_type); } void visit(LoopIndexStmt *stmt) { - stmt->ret_type = VectorType(1, PrimitiveType::i32); + stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32); } void visit(LoopLinearIndexStmt *stmt) { - stmt->ret_type = VectorType(1, PrimitiveType::i32); + stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32); } void visit(BlockCornerIndexStmt *stmt) { - stmt->ret_type = VectorType(1, PrimitiveType::i32); + stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32); } void visit(BlockDimStmt *stmt) { - stmt->ret_type = VectorType(1, PrimitiveType::i32); + stmt->ret_type = LegacyVectorType(1, PrimitiveType::i32); } void visit(GetRootStmt *stmt) { - stmt->ret_type = VectorType(1, PrimitiveType::gen, true); + stmt->ret_type = LegacyVectorType(1, PrimitiveType::gen, true); } void visit(SNodeLookupStmt *stmt) { - stmt->ret_type = VectorType(1, PrimitiveType::gen, true); + stmt->ret_type = LegacyVectorType(1, PrimitiveType::gen, true); } void visit(GetChStmt *stmt) { - stmt->ret_type = VectorType(1, stmt->output_snode->dt); + stmt->ret_type = LegacyVectorType(1, stmt->output_snode->dt); stmt->ret_type.set_is_pointer(true); } diff --git a/tests/cpp/test_alg_simp.cpp b/tests/cpp/test_alg_simp.cpp index 7ff940e0f16db..38e6e3fb41f16 100644 --- a/tests/cpp/test_alg_simp.cpp +++ b/tests/cpp/test_alg_simp.cpp @@ -17,13 +17,13 @@ TI_TEST("alg_simp") { block->kernel = kernel.get(); auto global_load_addr = block->push_back( - 0, VectorType(1, PrimitiveType::i32)); + 0, LegacyVectorType(1, PrimitiveType::i32)); auto global_load = block->push_back(global_load_addr); auto zero = block->push_back(TypedConstant(0)); auto add = block->push_back(BinaryOpType::add, global_load, zero); auto global_store_addr = block->push_back( - 4, VectorType(1, PrimitiveType::i32)); + 4, LegacyVectorType(1, PrimitiveType::i32)); auto global_store = block->push_back(global_store_addr, add); @@ -52,7 +52,7 @@ TI_TEST("alg_simp") { block->kernel = kernel.get(); auto global_load_addr = block->push_back( - 0, VectorType(1, PrimitiveType::f32)); + 0, LegacyVectorType(1, PrimitiveType::f32)); auto global_load = block->push_back(global_load_addr); auto one = block->push_back(TypedConstant(1.0f)); auto mul1 = @@ -62,7 +62,7 @@ TI_TEST("alg_simp") { auto div = block->push_back(BinaryOpType::div, zero, one); auto sub = block->push_back(BinaryOpType::sub, mul2, div); auto global_store_addr = block->push_back( - 4, VectorType(1, PrimitiveType::f32)); + 4, LegacyVectorType(1, PrimitiveType::f32)); auto global_store = block->push_back(global_store_addr, sub); @@ -90,7 +90,7 @@ TI_TEST("alg_simp") { block->kernel = kernel.get(); auto global_load_addr = block->push_back( - 0, VectorType(1, PrimitiveType::i32)); + 0, LegacyVectorType(1, PrimitiveType::i32)); auto global_load = block->push_back(global_load_addr); auto zero = block->push_back(TypedConstant(0)); auto mul = @@ -98,7 +98,7 @@ TI_TEST("alg_simp") { auto one = block->push_back(TypedConstant(1)); auto add = block->push_back(BinaryOpType::add, mul, one); auto global_store_addr = block->push_back( - 4, VectorType(1, PrimitiveType::i32)); + 4, LegacyVectorType(1, PrimitiveType::i32)); auto global_store = block->push_back(global_store_addr, add); @@ -118,14 +118,14 @@ TI_TEST("alg_simp") { block->kernel = kernel.get(); global_load_addr = block->push_back( - 8, VectorType(1, PrimitiveType::f32)); + 8, LegacyVectorType(1, PrimitiveType::f32)); global_load = block->push_back(global_load_addr); zero = block->push_back(TypedConstant(0)); mul = block->push_back(BinaryOpType::mul, global_load, zero); one = block->push_back(TypedConstant(1)); add = block->push_back(BinaryOpType::add, mul, one); global_store_addr = block->push_back( - 12, VectorType(1, PrimitiveType::f32)); + 12, LegacyVectorType(1, PrimitiveType::f32)); global_store = block->push_back(global_store_addr, add); irpass::type_check(block.get()); // insert 2 casts @@ -152,13 +152,13 @@ TI_TEST("alg_simp") { auto block = std::make_unique(); auto global_load_addr = block->push_back( - 0, VectorType(1, PrimitiveType::i32)); + 0, LegacyVectorType(1, PrimitiveType::i32)); auto global_load = block->push_back(global_load_addr); auto minus_one = block->push_back(TypedConstant(-1)); auto and_result = block->push_back(BinaryOpType::bit_and, minus_one, global_load); auto global_store_addr = block->push_back( - 4, VectorType(1, PrimitiveType::i32)); + 4, LegacyVectorType(1, PrimitiveType::i32)); auto global_store = block->push_back(global_store_addr, and_result); diff --git a/tests/cpp/test_same_statements.cpp b/tests/cpp/test_same_statements.cpp index a802f784c54a0..a0be1c68ebdbd 100644 --- a/tests/cpp/test_same_statements.cpp +++ b/tests/cpp/test_same_statements.cpp @@ -10,10 +10,10 @@ TI_TEST("same_statements") { auto block = std::make_unique(); auto global_load_addr = block->push_back( - 0, VectorType(1, PrimitiveType::i32)); + 0, LegacyVectorType(1, PrimitiveType::i32)); auto global_load = block->push_back(global_load_addr); auto global_store_addr = block->push_back( - 4, VectorType(1, PrimitiveType::i32)); + 4, LegacyVectorType(1, PrimitiveType::i32)); auto one = block->push_back(TypedConstant(1)); auto if_stmt = block->push_back(one)->as(); From 519233211e7fea11ce3cde43c06ab955f6b13717 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Mon, 12 Oct 2020 01:43:46 -0400 Subject: [PATCH 2/2] format --- taichi/ir/statements.h | 3 ++- taichi/transforms/make_thread_local.cpp | 7 ++++--- taichi/transforms/type_check.cpp | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 0b5309b917a64..0c6924bb77049 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -956,7 +956,8 @@ class ThreadLocalPtrStmt : public Stmt { public: std::size_t offset; - ThreadLocalPtrStmt(std::size_t offset, LegacyVectorType ret_type) : offset(offset) { + ThreadLocalPtrStmt(std::size_t offset, LegacyVectorType ret_type) + : offset(offset) { this->ret_type = ret_type; TI_STMT_REG_FIELDS; } diff --git a/taichi/transforms/make_thread_local.cpp b/taichi/transforms/make_thread_local.cpp index cb2fe181edb5b..898747e90c4fe 100644 --- a/taichi/transforms/make_thread_local.cpp +++ b/taichi/transforms/make_thread_local.cpp @@ -139,9 +139,10 @@ void make_thread_local_offload(OffloadedStmt *offload) { // Step 2: // Make loop body accumulate to TLS ptr instead of global ptr { - auto tls_ptr = offload->body->insert( - Stmt::make(tls_offset, LegacyVectorType(1, data_type)), - 0); + auto tls_ptr = + offload->body->insert(Stmt::make( + tls_offset, LegacyVectorType(1, data_type)), + 0); dest->replace_with(tls_ptr); } diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 44fa1df9bb4dd..f1690d2aed59a 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -344,7 +344,7 @@ class TypeCheck : public IRVisitor { void visit(ExternalPtrStmt *stmt) { stmt->ret_type.set_is_pointer(true); stmt->ret_type = LegacyVectorType(stmt->base_ptrs.size(), - stmt->base_ptrs[0]->ret_type.data_type); + stmt->base_ptrs[0]->ret_type.data_type); } void visit(LoopIndexStmt *stmt) {