diff --git a/taichi/analysis/clone.cpp b/taichi/analysis/clone.cpp index 2d76813a973cb..0b6754f0eea2b 100644 --- a/taichi/analysis/clone.cpp +++ b/taichi/analysis/clone.cpp @@ -134,6 +134,14 @@ namespace irpass::analysis { std::unique_ptr clone(IRNode *root) { return IRCloner::run(root); } + +std::unique_ptr clone(Stmt *root) { + auto ret = IRCloner::run(root); + Stmt *stmt_ptr = dynamic_cast(ret.release()); + TI_ASSERT(stmt_ptr != nullptr); + + return std::unique_ptr(stmt_ptr); +} } // namespace irpass::analysis } // namespace taichi::lang diff --git a/taichi/codegen/amdgpu/codegen_amdgpu.cpp b/taichi/codegen/amdgpu/codegen_amdgpu.cpp index 5b2431ef54eb5..5056334857552 100644 --- a/taichi/codegen/amdgpu/codegen_amdgpu.cpp +++ b/taichi/codegen/amdgpu/codegen_amdgpu.cpp @@ -472,9 +472,9 @@ LLVMCompiledTask KernelCodeGenAMDGPU::compile_task( int task_codegen_id, const CompileConfig &config, std::unique_ptr &&module, - OffloadedStmt *stmt) { + IRNode *block) { TaskCodeGenAMDGPU gen(task_codegen_id, config, get_taichi_llvm_context(), - kernel, stmt); + kernel, block); return gen.run_compilation(); } diff --git a/taichi/codegen/amdgpu/codegen_amdgpu.h b/taichi/codegen/amdgpu/codegen_amdgpu.h index aa75b16025377..8af8658ccb75d 100644 --- a/taichi/codegen/amdgpu/codegen_amdgpu.h +++ b/taichi/codegen/amdgpu/codegen_amdgpu.h @@ -22,7 +22,7 @@ class KernelCodeGenAMDGPU : public KernelCodeGen { int task_codegen_id, const CompileConfig &config, std::unique_ptr &&module = nullptr, - OffloadedStmt *stmt = nullptr) override; + IRNode *block = nullptr) override; #endif // TI_WITH_LLVM }; diff --git a/taichi/codegen/codegen.cpp b/taichi/codegen/codegen.cpp index 9aec49f17cbed..89c718086340c 100644 --- a/taichi/codegen/codegen.cpp +++ b/taichi/codegen/codegen.cpp @@ -86,8 +86,10 @@ LLVMCompiledKernel KernelCodeGen::compile_kernel_to_module() { tlctx_.fetch_this_thread_struct_module(); auto offload = irpass::analysis::clone(offloads[i].get()); irpass::re_id(offload.get()); - auto new_data = this->compile_task(i, compile_config_, nullptr, - offload->as()); + + Block blk; + blk.insert(std::move(offload)); + auto new_data = this->compile_task(i, compile_config_, nullptr, &blk); data[i] = std::make_unique(std::move(new_data)); }; worker.enqueue(compile_func); diff --git a/taichi/codegen/codegen.h b/taichi/codegen/codegen.h index 30a6099461b1f..e2c925287b8d5 100644 --- a/taichi/codegen/codegen.h +++ b/taichi/codegen/codegen.h @@ -63,7 +63,7 @@ class KernelCodeGen { int task_codegen_id, const CompileConfig &config, std::unique_ptr &&module = nullptr, - OffloadedStmt *stmt = nullptr) { + IRNode *block = nullptr) { TI_NOT_IMPLEMENTED } diff --git a/taichi/codegen/cpu/codegen_cpu.cpp b/taichi/codegen/cpu/codegen_cpu.cpp index 5cd1ac5906196..aa7ac005a4ac4 100644 --- a/taichi/codegen/cpu/codegen_cpu.cpp +++ b/taichi/codegen/cpu/codegen_cpu.cpp @@ -234,9 +234,9 @@ LLVMCompiledTask KernelCodeGenCPU::compile_task( int task_codegen_id, const CompileConfig &config, std::unique_ptr &&module, - OffloadedStmt *stmt) { + IRNode *block) { TaskCodeGenCPU gen(task_codegen_id, config, get_taichi_llvm_context(), kernel, - stmt); + block); return gen.run_compilation(); } diff --git a/taichi/codegen/cpu/codegen_cpu.h b/taichi/codegen/cpu/codegen_cpu.h index 8dfd3550008b2..e23c10e746014 100644 --- a/taichi/codegen/cpu/codegen_cpu.h +++ b/taichi/codegen/cpu/codegen_cpu.h @@ -24,7 +24,7 @@ class KernelCodeGenCPU : public KernelCodeGen { int task_codegen_id, const CompileConfig &config, std::unique_ptr &&module = nullptr, - OffloadedStmt *stmt = nullptr) override; + IRNode *block = nullptr) override; protected: void optimize_module(llvm::Module *module) override; diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index 9ffb507b3905a..e195c2f8373c3 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -759,9 +759,9 @@ LLVMCompiledTask KernelCodeGenCUDA::compile_task( int task_codegen_id, const CompileConfig &config, std::unique_ptr &&module, - OffloadedStmt *stmt) { + IRNode *block) { TaskCodeGenCUDA gen(task_codegen_id, config, get_taichi_llvm_context(), - kernel, stmt); + kernel, block); return gen.run_compilation(); } diff --git a/taichi/codegen/cuda/codegen_cuda.h b/taichi/codegen/cuda/codegen_cuda.h index be9a3eca14581..c694650093f45 100644 --- a/taichi/codegen/cuda/codegen_cuda.h +++ b/taichi/codegen/cuda/codegen_cuda.h @@ -22,7 +22,7 @@ class KernelCodeGenCUDA : public KernelCodeGen { int task_codegen_id, const CompileConfig &config, std::unique_ptr &&module = nullptr, - OffloadedStmt *stmt = nullptr) override; + IRNode *block = nullptr) override; #endif // TI_WITH_LLVM }; diff --git a/taichi/codegen/dx12/codegen_dx12.cpp b/taichi/codegen/dx12/codegen_dx12.cpp index 925b402d3f9c7..6c90471600c67 100644 --- a/taichi/codegen/dx12/codegen_dx12.cpp +++ b/taichi/codegen/dx12/codegen_dx12.cpp @@ -251,16 +251,18 @@ KernelCodeGenDX12::CompileResult KernelCodeGenDX12::compile() { for (int i = 0; i < offloads.size(); i++) { auto offload = irpass::analysis::clone(offloads[i].get()); irpass::re_id(offload.get()); - auto *offload_stmt = offload->as(); - auto new_data = compile_task(i, config, nullptr, offload_stmt); + auto offload_name = offload->as()->task_name(); + + Block blk; + blk.insert(std::move(offload)); + auto new_data = compile_task(i, config, nullptr, &blk); Result.task_dxil_source_codes.emplace_back( generate_dxil_from_llvm(new_data, config, kernel)); aot::CompiledOffloadedTask task; // FIXME: build all fields for task. - task.name = fmt::format("{}_{}_{}", kernel->get_name(), - offload_stmt->task_name(), i); - task.type = offload_stmt->task_name(); + task.name = fmt::format("{}_{}_{}", kernel->get_name(), offload_name, i); + task.type = offload_name; Result.tasks.emplace_back(task); } // FIXME: set correct num_snode_trees. @@ -272,9 +274,9 @@ LLVMCompiledTask KernelCodeGenDX12::compile_task( int task_codegen_id, const CompileConfig &config, std::unique_ptr &&module, - OffloadedStmt *stmt) { + IRNode *block) { TaskCodeGenLLVMDX12 gen(task_codegen_id, config, get_taichi_llvm_context(), - kernel, stmt); + kernel, block); return gen.run_compilation(); } #endif // TI_WITH_LLVM diff --git a/taichi/codegen/dx12/codegen_dx12.h b/taichi/codegen/dx12/codegen_dx12.h index 30bb693abe79a..f7c84fc8395ea 100644 --- a/taichi/codegen/dx12/codegen_dx12.h +++ b/taichi/codegen/dx12/codegen_dx12.h @@ -29,7 +29,7 @@ class KernelCodeGenDX12 : public KernelCodeGen { int task_codegen_id, const CompileConfig &config, std::unique_ptr &&module = nullptr, - OffloadedStmt *stmt = nullptr) override; + IRNode *block = nullptr) override; #endif }; diff --git a/taichi/ir/analysis.h b/taichi/ir/analysis.h index 3161d38991420..48eb0b4cae128 100644 --- a/taichi/ir/analysis.h +++ b/taichi/ir/analysis.h @@ -77,6 +77,7 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2); std::unique_ptr build_cfg(IRNode *root); void check_fields_registered(IRNode *root); std::unique_ptr clone(IRNode *root); +std::unique_ptr clone(Stmt *root); int count_statements(IRNode *root); /** diff --git a/tests/python/test_continue.py b/tests/python/test_continue.py index cdb13d54cd73a..a837112703809 100644 --- a/tests/python/test_continue.py +++ b/tests/python/test_continue.py @@ -147,3 +147,26 @@ def run(a: ti.i32): assert x[0] == 1 run(0) assert x[0] == 0 + + +@test_utils.test() +def test_kernel_continue_in_simple_if(): + img = ti.field(ti.i32, (2, 2)) + + @ti.kernel + def K(): + for i, j in img: + img[i, j] = 0 + if i > 0 or j > 0: + continue + img[i, j] = 1 + + img.fill(2) + K() + + for i in range(2): + for j in range(2): + if i > 0 or j > 0: + assert img[i, j] == 0 + else: + assert img[i, j] == 1