diff --git a/python/taichi/testing.py b/python/taichi/testing.py index 8fe0aaabdd42b..39bb4d9cfae5b 100644 --- a/python/taichi/testing.py +++ b/python/taichi/testing.py @@ -138,8 +138,6 @@ def wrapped(*args, **kwargs): current_options[feature] = value if skip: continue - if current_options.get('dynamic_index', False): - continue ti.init(arch=req_arch, **current_options) foo(*args, **kwargs) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 192f4c55eabbd..b3d21531b7c59 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -296,6 +296,7 @@ void LocalTensorElementExpression::flatten(FlattenContext *ctx) { // ^^^^^^^^^ indices[0].set(load_if_ptr(indices[0])); indices[0]->flatten(ctx); + Stmt *offset_stmt = indices[0]->stmt; for (int i = 1; i < (int)shape.size(); ++i) { Stmt *accumulated_stmt = ctx->back_stmt(); Stmt *current_length_stmt = @@ -307,10 +308,10 @@ void LocalTensorElementExpression::flatten(FlattenContext *ctx) { Stmt *current_index_stmt = ctx->back_stmt(); ctx->push_back(Stmt::make(BinaryOpType::add, mul_stmt, current_index_stmt)); + offset_stmt = ctx->back_stmt(); } // Type A[x, y, ...] // ^^^^ - Stmt *offset_stmt = ctx->back_stmt(); Stmt *dt_size_stmt = ctx->push_back( Stmt::make(TypedConstant(data_type_size(element_type)))); ctx->push_back( diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index f9a8e5dc35156..9f3215ef1e742 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -14,6 +14,28 @@ TLANG_NAMESPACE_BEGIN namespace irpass { namespace { +class SquashPtrOffset : public IRVisitor { + public: + SquashPtrOffset() { + allow_undefined_visitor = true; + invoke_default_visitor = true; + } + void visit(Stmt *stmt) override { + top_level_ptr = stmt; + } + void visit(PtrOffsetStmt *stmt) override { + stmt->origin->accept(this); + } + static Stmt *run(Stmt *root) { + SquashPtrOffset v; + root->accept(&v); + return v.top_level_ptr; + } + + private: + Stmt *top_level_ptr = nullptr; +}; + // Offloaded local variables to its offset in the global tmps memory. using StmtToOffsetMap = std::unordered_map; @@ -302,28 +324,16 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor { return; if (stmt_to_offloaded[stmt] == current_offloaded) return; - if (config.advanced_optimization) { - if (stmt->is()) { - // Directly insert copies of ConstStmts later - return; - } - } - if (stmt->is()) { - // We don't support storing a pointer for now. + // Directly insert copies of ConstStmts later + if (stmt->is()) + return; + auto top_level_ptr = SquashPtrOffset::run(stmt); + // We don't support storing a pointer for now. + if (top_level_ptr->is()) return; - } // Not yet allocated - if (stmt->is()) { - if (local_to_global.find(stmt->cast()->origin) == - local_to_global.end()) { - auto alloca_stmt = stmt->cast()->origin; - local_to_global[alloca_stmt] = allocate_global(alloca_stmt->ret_type); - } - } else { - // stmt might be AllocaStmt, ExternalTensorShapeAlongAxisStmt - if (local_to_global.find(stmt) == local_to_global.end()) { - local_to_global[stmt] = allocate_global(stmt->ret_type); - } + if (local_to_global.find(top_level_ptr) == local_to_global.end()) { + local_to_global[top_level_ptr] = allocate_global(top_level_ptr->ret_type); } } @@ -421,14 +431,28 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { if (stmt->body) stmt->body->accept(this); if (stmt->task_type == OffloadedStmt::TaskType::range_for) { - if (!stmt->const_begin) + if (!stmt->const_begin) { + TI_ASSERT(offloaded_ranges_->begin_stmts.find(stmt) != + offloaded_ranges_->begin_stmts.end()) + TI_ASSERT_INFO(local_to_global_offset.find( + offloaded_ranges_->begin_stmts.find(stmt)->second) != + local_to_global_offset.end(), + "Begin fails.") stmt->begin_offset = local_to_global_offset[offloaded_ranges_->begin_stmts.find(stmt) ->second]; - if (!stmt->const_end) + } + if (!stmt->const_end) { + TI_ASSERT(offloaded_ranges_->end_stmts.find(stmt) != + offloaded_ranges_->end_stmts.end()) + TI_ASSERT_INFO(local_to_global_offset.find( + offloaded_ranges_->end_stmts.find(stmt)->second) != + local_to_global_offset.end(), + "End fails.") stmt->end_offset = local_to_global_offset[offloaded_ranges_->end_stmts.find(stmt) ->second]; + } } } @@ -441,10 +465,13 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { local_to_global_vector_type[stmt] = ret_type; auto ptr = replacement.push_back( local_to_global_offset[stmt], ret_type); + auto offloaded = stmt_to_offloaded[stmt]; + stmt_to_offloaded[ptr] = offloaded; if (auto tensor_type = stmt->ret_type->cast()) { LaneAttribute zero(std::vector( 1, TypedConstant(tensor_type->get_element_type()))); auto const_zero_stmt = replacement.push_back(zero); + stmt_to_offloaded[const_zero_stmt] = offloaded; for (int i = 0; i < tensor_type->get_num_elements(); ++i) { LaneAttribute offset(std::vector( 1, TypedConstant(i * @@ -452,155 +479,54 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { auto const_offset_stmt = replacement.push_back(offset); auto ptr_offset_stmt = replacement.push_back(ptr, const_offset_stmt); - replacement.push_back(ptr_offset_stmt, - const_zero_stmt); + auto global_store_stmt = replacement.push_back( + ptr_offset_stmt, const_zero_stmt); + stmt_to_offloaded[const_offset_stmt] = offloaded; + stmt_to_offloaded[ptr_offset_stmt] = offloaded; + stmt_to_offloaded[global_store_stmt] = offloaded; } } else { LaneAttribute zeros(std::vector( stmt->width(), TypedConstant(stmt->ret_type))); auto const_zeros = replacement.push_back(zeros); - replacement.push_back(ptr, const_zeros); + auto global_store_stmt = + replacement.push_back(ptr, const_zeros); + stmt_to_offloaded[global_store_stmt] = offloaded; } stmt->parent->replace_with(stmt, std::move(replacement), false); + // To deal with the same offloaded visit_operand() + stmt_to_offloaded[stmt] = nullptr; throw IRModified(); } - void visit(PtrOffsetStmt *stmt) override { - auto alloca = stmt->origin; - auto offset = stmt->offset; - - if (stmt_to_offloaded[offset] != stmt_to_offloaded[stmt]) { + // Replace local LD/ST with global LD/ST + void visit(LocalLoadStmt *stmt) override { + generic_visit(stmt); + TI_ASSERT(stmt->width() == 1) + auto ptr = stmt->src[0].var; + auto top_level_ptr = SquashPtrOffset::run(ptr); + if (top_level_ptr->is()) { VecStatement replacement; - // TODO: offset may not be ConstStmt - auto copy_stmt = offset->as()->copy(); - auto copy = replacement.push_back(std::move(copy_stmt)); - stmt_to_offloaded[copy] = stmt_to_offloaded[stmt]; - auto ptr_offset = replacement.push_back(alloca, copy); + auto global_load = replacement.push_back(ptr); + stmt_to_offloaded[global_load] = stmt_to_offloaded[stmt]; stmt->parent->replace_with(stmt, std::move(replacement)); throw IRModified(); } - - if (!(alloca->is() && - alloca->cast()->ret_type->is())) - return; - if (local_to_global_offset.find(alloca) == local_to_global_offset.end()) - return; - - VecStatement replacement; - auto ret_type = alloca->cast()->ret_type; - - auto ptr = replacement.push_back( - local_to_global_offset[alloca], ret_type); - replacement.push_back(ptr, stmt->offset); - - stmt->parent->replace_with(stmt, std::move(replacement)); - throw IRModified(); - } - - // Replace local LD/ST with global LD/ST - void visit(LocalLoadStmt *stmt) override { - TI_ASSERT(stmt->width() == 1); - // TensorType Alloca - if (stmt->src[0].var->is() && - stmt->src[0].var->as()->origin->is()) { - auto alloca = - stmt->src[0].var->as()->origin->as(); - if (local_to_global_offset.find(alloca) != local_to_global_offset.end()) { - // Converted to GlobalTemporaryStmt - VecStatement replacement; - auto ptr = replacement.push_back( - local_to_global_offset[alloca], alloca->ret_type); - // TODO: offset may not be ConstStmt - auto copy_stmt = stmt->src[0] - .var->as() - ->offset->as() - ->copy(); - auto copy = replacement.push_back(std::move(copy_stmt)); - auto ptr_offset = replacement.push_back(ptr, copy); - replacement.push_back(ptr_offset); - - stmt->parent->replace_with(stmt, std::move(replacement)); - throw IRModified(); - } - } - // Scalar Alloca - auto alloca = stmt->src[0].var; - if (local_to_global_offset.find(alloca) == local_to_global_offset.end()) - return; - - VecStatement replacement; - auto ret_type = stmt->ret_type; - - auto ptr = replacement.push_back( - local_to_global_offset[alloca], ret_type); - replacement.push_back(ptr); - - stmt->parent->replace_with(stmt, std::move(replacement)); - throw IRModified(); } void visit(LocalStoreStmt *stmt) override { - if (visit_operand(stmt, stmt->locate_operand(&stmt->val))) + generic_visit(stmt); + auto ptr = stmt->dest; + auto top_level_ptr = SquashPtrOffset::run(ptr); + if (top_level_ptr->is()) { + VecStatement replacement; + auto global_store = + replacement.push_back(ptr, stmt->val); + stmt_to_offloaded[global_store] = stmt_to_offloaded[stmt]; + stmt->parent->replace_with(stmt, std::move(replacement)); throw IRModified(); - TI_ASSERT(stmt->width() == 1); - // TensorType Alloca - if (stmt->dest->is() && - stmt->dest->as()->origin->is()) { - auto alloca = stmt->dest->as()->origin->as(); - if (local_to_global_offset.find(alloca) != local_to_global_offset.end()) { - // Converted to GlobalTemporaryStmt - VecStatement replacement; - auto ptr = replacement.push_back( - local_to_global_offset[alloca], alloca->ret_type); - // TODO: offset may not be ConstStmt - auto copy_stmt = - stmt->dest->as()->offset->as()->copy(); - auto copy = replacement.push_back(std::move(copy_stmt)); - auto ptr_offset = replacement.push_back(ptr, copy); - replacement.push_back(ptr_offset, stmt->val); - - stmt->parent->replace_with(stmt, std::move(replacement)); - throw IRModified(); - } - } - // Scalar Alloca - auto alloca = stmt->dest; - if (local_to_global_offset.find(alloca) == local_to_global_offset.end()) - return; - - VecStatement replacement; - auto ret_type = stmt->ret_type; - - auto ptr = replacement.push_back( - local_to_global_offset[alloca], ret_type); - replacement.push_back(ptr, stmt->val); - - stmt->parent->replace_with(stmt, std::move(replacement)); - throw IRModified(); - } - - void visit(AtomicOpStmt *stmt) override { - if (!stmt->dest->is()) { - generic_visit(stmt); - return; } - if (visit_operand(stmt, stmt->locate_operand(&stmt->val))) - throw IRModified(); - TI_ASSERT(stmt->width() == 1); - auto alloca = stmt->dest; - if (local_to_global_offset.find(alloca) == local_to_global_offset.end()) - return; - - VecStatement replacement; - auto ret_type = stmt->dest->ret_type; - - auto ptr = replacement.push_back( - local_to_global_offset[alloca], ret_type); - replacement.push_back(stmt->op_type, ptr, stmt->val); - - stmt->parent->replace_with(stmt, std::move(replacement)); - throw IRModified(); } bool visit_operand(Stmt *stmt, int index) { @@ -611,55 +537,45 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { return false; if (stmt_to_offloaded[stmt] == stmt_to_offloaded[op]) // same OffloadedStmt return false; - if (config.advanced_optimization) { - if (op->is()) { - auto copy = op->as()->copy(); - stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt]; - stmt->set_operand(index, copy.get()); - stmt->insert_before_me(std::move(copy)); - return true; - } - } + + auto offloaded = stmt_to_offloaded[stmt]; + if (op->is()) { auto copy = op->clone(); copy->as()->activate = false; - stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt]; + stmt_to_offloaded[copy.get()] = offloaded; stmt->set_operand(index, copy.get()); stmt->insert_before_me(std::move(copy)); return true; } - if (op->is() && - op->cast()->origin->is()) { - auto alloca = op->as()->origin->as(); - if (local_to_global_offset.find(alloca) == local_to_global_offset.end()) - return false; - - auto global_temporary_stmt = Stmt::make( - local_to_global_offset[alloca], alloca->ret_type); - // TODO: offset may not be ConstStmt - auto copy_stmt = op->as()->offset->as()->copy(); - auto ptr_offset_stmt = Stmt::make( - global_temporary_stmt.get(), copy_stmt.get()); - stmt_to_offloaded[copy_stmt.get()] = stmt_to_offloaded[stmt]; - stmt_to_offloaded[ptr_offset_stmt.get()] = stmt_to_offloaded[stmt]; - stmt->set_operand(index, ptr_offset_stmt.get()); - stmt->insert_before_me(std::move(copy_stmt)); - stmt->insert_before_me(std::move(global_temporary_stmt)); - stmt->insert_before_me(std::move(ptr_offset_stmt)); - return true; + if (local_to_global_offset.find(op) == local_to_global_offset.end()) { + TI_ASSERT_INFO(op->is() || op->is() || + op->is(), + "{} is not allowed here.", op->type()); + // For cases like ConstStmt + auto copy = op->clone(); + stmt_to_offloaded[copy.get()] = offloaded; + stmt->set_operand(index, copy.get()); + stmt->insert_before_me(std::move(copy)); + } else { + auto global_temporary = Stmt::make( + local_to_global_offset[op], op->ret_type); + stmt_to_offloaded[global_temporary.get()] = offloaded; + stmt->set_operand(index, global_temporary.get()); + if (op->is() || op->ret_type.is_pointer()) { + // For cases like Alloca both TensorType and Scalar which will be + // followed by LocalLoad. Avoid repeated loads here. + stmt->insert_before_me(std::move(global_temporary)); + } else { + // For other cases like ArgLoadStmt UnaryOpStmt which needs to load. + auto load = Stmt::make(global_temporary.get()); + stmt_to_offloaded[load.get()] = offloaded; + stmt->set_operand(index, load.get()); + stmt->insert_before_me(std::move(global_temporary)); + stmt->insert_before_me(std::move(load)); + } } - - if (local_to_global_offset.find(op) == local_to_global_offset.end()) - return false; - - auto global = Stmt::make(local_to_global_offset[op], - op->ret_type); - auto load = Stmt::make(global.get()); - stmt_to_offloaded[load.get()] = stmt_to_offloaded[stmt]; - stmt->set_operand(index, load.get()); - stmt->insert_before_me(std::move(global)); - stmt->insert_before_me(std::move(load)); return true; } @@ -675,7 +591,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { } void visit(Stmt *stmt) override { - TI_ASSERT(stmt->width() == 1 || (stmt->ret_type->is())); + TI_ASSERT(stmt->width() == 1) generic_visit(stmt); } diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 71df8bc45075b..ce85c8557efc0 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -195,6 +195,9 @@ def func3(): tmp = ti.Vector([1, 2, 3], dt=ti.i32) for i in range(3): tmp[i] = i * i + vec = ti.Vector([4, 5, 6], dt=ti.i32) + for j in range(3): + vec[tmp[i] % 3] += vec[j % 3] assert tmp[0] == 0 assert tmp[1] == 1 assert tmp[2] == 4