diff --git a/misc/test_fuse_dense.py b/misc/test_fuse_dense.py new file mode 100644 index 0000000000000..14463fcf25139 --- /dev/null +++ b/misc/test_fuse_dense.py @@ -0,0 +1,75 @@ +import taichi as ti +import time + +ti.init(async=True) + +x = ti.var(ti.i32) +y = ti.var(ti.i32) +z = ti.var(ti.i32) + +ti.root.dense(ti.i, 1024**3).place(x) +ti.root.dense(ti.i, 1024**3).place(y) +ti.root.dense(ti.i, 1024**3).place(z) + + +@ti.kernel +def x_to_y(): + for i in x: + y[i] = x[i] + 1 + + +@ti.kernel +def y_to_z(): + for i in x: + z[i] = y[i] + 4 + + +@ti.kernel +def inc(): + for i in x: + x[i] = x[i] + 1 + + +n = 100 + +for i in range(n): + x[i] = i * 10 + +repeat = 10 + +for i in range(repeat): + t = time.time() + x_to_y() + ti.sync() + print('x_to_y', time.time() - t) + +for i in range(repeat): + t = time.time() + y_to_z() + ti.sync() + print('y_to_z', time.time() - t) + +for i in range(repeat): + t = time.time() + x_to_y() + y_to_z() + ti.sync() + print('fused x->y->z', time.time() - t) + +for i in range(repeat): + t = time.time() + inc() + ti.sync() + print('single inc', time.time() - t) + +for i in range(repeat): + t = time.time() + for j in range(10): + inc() + ti.sync() + print('fused 10 inc', time.time() - t) + +for i in range(n): + assert x[i] == i * 10 + assert y[i] == x[i] + 1 + assert z[i] == x[i] + 5 diff --git a/misc/test_fuse_dynamic.py b/misc/test_fuse_dynamic.py new file mode 100644 index 0000000000000..cc30663a7c4b8 --- /dev/null +++ b/misc/test_fuse_dynamic.py @@ -0,0 +1,35 @@ +import taichi as ti + +ti.init() + +x = ti.var(ti.i32) +y = ti.var(ti.i32) +z = ti.var(ti.i32) + +ti.root.dynamic(ti.i, 1048576, chunk_size=2048).place(x, y, z) + + +@ti.kernel +def x_to_y(): + for i in x: + y[i] = x[i] + 1 + + +@ti.kernel +def y_to_z(): + for i in x: + z[i] = y[i] + 1 + + +n = 10000 + +for i in range(n): + x[i] = i * 10 + +x_to_y() +y_to_z() + +for i in range(n): + x[i] = i * 10 + assert y[i] == x[i] + 1 + assert z[i] == x[i] + 2 diff --git a/taichi/analysis/clone.cpp b/taichi/analysis/clone.cpp index 026de6073fa85..751b0086e4795 100644 --- a/taichi/analysis/clone.cpp +++ b/taichi/analysis/clone.cpp @@ -2,6 +2,8 @@ #include "taichi/ir/analysis.h" #include "taichi/ir/transforms.h" #include "taichi/ir/visitors.h" +#include "taichi/program/program.h" + #include TLANG_NAMESPACE_BEGIN @@ -12,10 +14,7 @@ class IRCloner : public IRVisitor { std::unordered_map operand_map; public: - enum Phase { - register_operand_map, - replace_operand - } phase; + enum Phase { register_operand_map, replace_operand } phase; explicit IRCloner(IRNode *other_node) : other_node(other_node), phase(register_operand_map) { @@ -111,22 +110,25 @@ class IRCloner : public IRVisitor { } } - static std::unique_ptr run(IRNode *root) { + static std::unique_ptr run(IRNode *root, Kernel *kernel) { + if (kernel == nullptr) { + kernel = &get_current_program().get_current_kernel(); + } std::unique_ptr new_root = root->clone(); IRCloner cloner(new_root.get()); cloner.phase = IRCloner::register_operand_map; root->accept(&cloner); cloner.phase = IRCloner::replace_operand; root->accept(&cloner); - irpass::typecheck(new_root.get()); + irpass::typecheck(new_root.get(), kernel); irpass::fix_block_parents(new_root.get()); return new_root; } }; namespace irpass::analysis { -std::unique_ptr clone(IRNode *root) { - return IRCloner::run(root); +std::unique_ptr clone(IRNode *root, Kernel *kernel) { + return IRCloner::run(root, kernel); } } // namespace irpass::analysis diff --git a/taichi/ir/analysis.h b/taichi/ir/analysis.h index 00b483364cd95..0d08a8dda76ab 100644 --- a/taichi/ir/analysis.h +++ b/taichi/ir/analysis.h @@ -51,7 +51,7 @@ class DiffRange { namespace irpass::analysis { void check_fields_registered(IRNode *root); -std::unique_ptr clone(IRNode *root); +std::unique_ptr clone(IRNode *root, Kernel *kernel = nullptr); int count_statements(IRNode *root); std::unordered_set detect_fors_with_break(IRNode *root); std::unordered_set detect_loops_with_continue(IRNode *root); diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp index cc1fbd9627d70..459bb2ab4bdc3 100644 --- a/taichi/program/async_engine.cpp +++ b/taichi/program/async_engine.cpp @@ -12,9 +12,11 @@ TLANG_NAMESPACE_BEGIN -uint64 hash(OffloadedStmt *stmt) { +uint64 hash(IRNode *stmt) { + TI_ASSERT(stmt); // TODO: upgrade this using IR comparisons std::string serialized; + irpass::re_id(stmt); irpass::print(stmt, &serialized); uint64 ret = 0; for (uint64 i = 0; i < serialized.size(); i++) { @@ -25,33 +27,41 @@ uint64 hash(OffloadedStmt *stmt) { KernelLaunchRecord::KernelLaunchRecord(Context context, Kernel *kernel, - OffloadedStmt *stmt) - : context(context), kernel(kernel), stmt(stmt), h(hash(stmt)) { + std::unique_ptr &&stmt_) + : context(context), + kernel(kernel), + stmt(dynamic_cast(stmt_.get())), + stmt_(std::move(stmt_)), + h(hash(stmt)) { + TI_ASSERT(stmt != nullptr); } -void ExecutionQueue::enqueue(KernelLaunchRecord ker) { +void ExecutionQueue::enqueue(KernelLaunchRecord &&ker) { auto h = ker.h; + auto stmt = ker.stmt; + auto kernel = ker.kernel; if (compiled_func.find(h) == compiled_func.end() && to_be_compiled.find(h) == to_be_compiled.end()) { to_be_compiled.insert(h); - compilation_workers.enqueue([&, ker, h, this]() { + compilation_workers.enqueue([&, stmt, kernel, h, this]() { { // Final lowering using namespace irpass; - flag_access(ker.stmt); - lower_access(ker.stmt, true, ker.kernel); - flag_access(ker.stmt); - full_simplify(ker.stmt, ker.kernel->program.config, ker.kernel); - // analysis::verify(ker.stmt); + flag_access(stmt); + lower_access(stmt, true, kernel); + flag_access(stmt); + full_simplify(stmt, kernel->program.config, kernel); + // analysis::verify(stmt); } - auto func = CodeGenCPU(ker.kernel, ker.stmt).codegen(); + auto func = CodeGenCPU(kernel, stmt).codegen(); std::lock_guard _(mut); compiled_func[h] = func; }); } - launch_worker.enqueue([&, ker, h] { + auto context = ker.context; + launch_worker.enqueue([&, h, stmt, context, this] { FunctionType func; while (true) { std::unique_lock lock(mut); @@ -64,7 +74,7 @@ void ExecutionQueue::enqueue(KernelLaunchRecord ker) { break; } stat.add("launched_kernels", 1.0); - auto task_type = ker.stmt->task_type; + auto task_type = stmt->task_type; if (task_type == OffloadedStmt::TaskType::listgen) { stat.add("launched_kernels_list_op", 1.0); stat.add("launched_kernels_list_gen", 1.0); @@ -80,9 +90,10 @@ void ExecutionQueue::enqueue(KernelLaunchRecord ker) { } else if (task_type == OffloadedStmt::TaskType::gc) { stat.add("launched_kernels_garbage_collect", 1.0); } - auto context = ker.context; - func(context); + auto c = context; + func(c); }); + trashbin.push_back(std::move(ker)); } void ExecutionQueue::synchronize() { @@ -102,19 +113,19 @@ void AsyncEngine::launch(Kernel *kernel) { auto &offloads = block->statements; for (std::size_t i = 0; i < offloads.size(); i++) { auto offload = offloads[i]->as(); - KernelLaunchRecord rec(kernel->program.get_context(), kernel, offload); - enqueue(rec); + KernelLaunchRecord rec(kernel->program.get_context(), kernel, + irpass::analysis::clone(offload, kernel)); + enqueue(std::move(rec)); } } -void AsyncEngine::enqueue(KernelLaunchRecord t) { +void AsyncEngine::enqueue(KernelLaunchRecord &&t) { using namespace irpass::analysis; - task_queue.push_back(t); - auto &meta = metas[t.h]; // TODO: this is an abuse since it gathers nothing... - gather_statements(t.stmt, [&](Stmt *stmt) { + auto root_stmt = t.stmt; + gather_statements(root_stmt, [&](Stmt *stmt) { if (auto global_ptr = stmt->cast()) { for (auto &snode : global_ptr->snodes.data) { meta.input_snodes.insert(snode); @@ -153,25 +164,29 @@ void AsyncEngine::enqueue(KernelLaunchRecord t) { } return false; }); + + task_queue.push_back(std::move(t)); } void AsyncEngine::synchronize() { - optimize(); + optimize_listgen(); + while (fuse()) + ; while (!task_queue.empty()) { - queue.enqueue(task_queue.front()); + queue.enqueue(std::move(task_queue.front())); task_queue.pop_front(); } queue.synchronize(); } -bool AsyncEngine::optimize() { +bool AsyncEngine::optimize_listgen() { // TODO: improve... bool modified = false; std::unordered_map list_dirty; auto new_task_queue = std::deque(); for (int i = 0; i < task_queue.size(); i++) { // Try to eliminate unused listgens - auto t = task_queue[i]; + auto &t = task_queue[i]; auto meta = metas[t.h]; auto offload = t.stmt; bool keep = true; @@ -197,7 +212,7 @@ bool AsyncEngine::optimize() { } } if (keep) { - new_task_queue.push_back(t); + new_task_queue.push_back(std::move(t)); } else { modified = true; } @@ -206,4 +221,78 @@ bool AsyncEngine::optimize() { return modified; } +bool AsyncEngine::fuse() { + // TODO: improve... + bool modified = false; + std::unordered_map list_dirty; + + if (false) { + // (experimental) print tasks + for (int i = 0; i < (int)task_queue.size(); i++) { + fmt::print("{}: {}\n", i, task_queue[i].stmt->task_name()); + irpass::print(task_queue[i].stmt); + } + } + + for (int i = 0; i < (int)task_queue.size() - 1; i++) { + auto task_a = task_queue[i].stmt; + auto task_b = task_queue[i + 1].stmt; + bool is_same_struct_for = task_a->task_type == OffloadedStmt::struct_for && + task_b->task_type == OffloadedStmt::struct_for && + task_a->snode == task_b->snode && + task_a->block_dim == task_b->block_dim; + bool is_same_range_for = task_a->task_type == OffloadedStmt::range_for && + task_b->task_type == OffloadedStmt::range_for && + task_a->const_begin && task_b->const_begin && + task_a->const_end && task_b->const_end && + task_a->begin_value == task_b->begin_value && + task_a->end_value == task_b->end_value; + + // We do not fuse serial kernels for now since they can be SNode accessors + bool are_both_serial = task_a->task_type == OffloadedStmt::serial && + task_b->task_type == OffloadedStmt::serial; + bool same_kernel = task_queue[i].kernel == task_queue[i + 1].kernel; + if (is_same_range_for || is_same_struct_for) { + // TODO: in certain cases this optimization can be wrong! + // Fuse task b into task_a + for (int j = 0; j < (int)task_b->body->size(); j++) { + task_a->body->insert(std::move(task_b->body->statements[j])); + } + task_b->body->statements.clear(); + + // replace all reference to the offloaded statement B to A + irpass::replace_all_usages_with(task_a, task_b, task_a); + irpass::re_id(task_a); + irpass::fix_block_parents(task_a); + + auto kernel = task_queue[i].kernel; + irpass::full_simplify(task_a, kernel->program.config, kernel); + task_queue[i].h = hash(task_a); + + modified = true; + } + } + + auto new_task_queue = std::deque(); + + // Eliminate empty tasks + for (int i = 0; i < (int)task_queue.size(); i++) { + auto task = task_queue[i].stmt; + bool keep = true; + if (task->task_type == OffloadedStmt::struct_for || + task->task_type == OffloadedStmt::range_for || + task->task_type == OffloadedStmt::serial) { + if (task->body->statements.empty()) + keep = false; + } + if (keep) { + new_task_queue.push_back(std::move(task_queue[i])); + } + } + + task_queue = std::move(new_task_queue); + + return modified; +} + TLANG_NAMESPACE_END diff --git a/taichi/program/async_engine.h b/taichi/program/async_engine.h index 4349aa35713cf..3b80908b77a91 100644 --- a/taichi/program/async_engine.h +++ b/taichi/program/async_engine.h @@ -109,9 +109,12 @@ class KernelLaunchRecord { Context context; Kernel *kernel; // TODO: remove this OffloadedStmt *stmt; + std::unique_ptr stmt_; uint64 h; - KernelLaunchRecord(Context contxet, Kernel *kernel, OffloadedStmt *stmt); + KernelLaunchRecord(Context contxet, + Kernel *kernel, + std::unique_ptr &&stmt); }; // In charge of (parallel) compilation to binary and (serial) kernel launching @@ -119,6 +122,7 @@ class ExecutionQueue { public: std::mutex mut; std::deque task_queue; + std::vector trashbin; // prevent IR from being deleted std::unordered_set to_be_compiled; ParallelExecutor compilation_workers; // parallel compilation @@ -128,7 +132,7 @@ class ExecutionQueue { ExecutionQueue(); - void enqueue(KernelLaunchRecord ker); + void enqueue(KernelLaunchRecord &&ker); void compile_task() { } @@ -163,7 +167,9 @@ class AsyncEngine { AsyncEngine() { } - bool optimize(); // return true when modified + bool optimize_listgen(); // return true when modified + + bool fuse(); // return true when modified void clear_cache() { queue.clear_cache(); @@ -171,7 +177,7 @@ class AsyncEngine { void launch(Kernel *kernel); - void enqueue(KernelLaunchRecord t); + void enqueue(KernelLaunchRecord &&t); void synchronize(); };