From 8066f43e5d25cc87694ea8a8f4835d308168dabe Mon Sep 17 00:00:00 2001 From: PGZXB Date: Thu, 5 Jan 2023 19:31:58 +0800 Subject: [PATCH] [refactor] Remove dependencies on Program::current_ast_builder() in C++ side (#7044) Issue: #7002 --- taichi/program/program.cpp | 16 ++++++++-------- taichi/program/program.h | 10 ---------- taichi/transforms/constant_fold.cpp | 14 ++++++++------ 3 files changed, 16 insertions(+), 24 deletions(-) diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index ef9ab892b9e4f..98d7a030e6df3 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -359,17 +359,17 @@ Arch Program::get_accessor_arch() { Kernel &Program::get_snode_reader(SNode *snode) { TI_ASSERT(snode->type == SNodeType::place); auto kernel_name = fmt::format("snode_reader_{}", snode->id); - auto &ker = kernel([snode, this] { + auto &ker = kernel([snode, this](Kernel *kernel) { ExprGroup indices; for (int i = 0; i < snode->num_active_indices; i++) { auto argload_expr = Expr::make(i, PrimitiveType::i32); argload_expr->type_check(&this->this_thread_config()); indices.push_back(std::move(argload_expr)); } - ASTBuilder *builder = this->current_ast_builder(); + ASTBuilder &builder = kernel->context->builder(); auto ret = Stmt::make(ExprGroup( - builder->expr_subscript(Expr(snode_to_fields_.at(snode)), indices))); - this->current_ast_builder()->insert(std::move(ret)); + builder.expr_subscript(Expr(snode_to_fields_.at(snode)), indices))); + builder.insert(std::move(ret)); }); ker.set_arch(get_accessor_arch()); ker.name = kernel_name; @@ -383,17 +383,17 @@ Kernel &Program::get_snode_reader(SNode *snode) { Kernel &Program::get_snode_writer(SNode *snode) { TI_ASSERT(snode->type == SNodeType::place); auto kernel_name = fmt::format("snode_writer_{}", snode->id); - auto &ker = kernel([snode, this] { + auto &ker = kernel([snode, this](Kernel *kernel) { ExprGroup indices; for (int i = 0; i < snode->num_active_indices; i++) { auto argload_expr = Expr::make(i, PrimitiveType::i32); argload_expr->type_check(&this->this_thread_config()); indices.push_back(std::move(argload_expr)); } - ASTBuilder *builder = current_ast_builder(); + ASTBuilder &builder = kernel->context->builder(); auto expr = - builder->expr_subscript(Expr(snode_to_fields_.at(snode)), indices); - this->current_ast_builder()->insert_assignment( + builder.expr_subscript(Expr(snode_to_fields_.at(snode)), indices); + builder.insert_assignment( expr, Expr::make(snode->num_active_indices, snode->dt->get_compute_type()), diff --git a/taichi/program/program.h b/taichi/program/program.h index 5ca58d0bb0b34..121e43277f0d7 100644 --- a/taichi/program/program.h +++ b/taichi/program/program.h @@ -184,16 +184,6 @@ class TI_DLL_EXPORT Program { void visualize_layout(const std::string &fn); - Kernel &kernel(const std::function &body, - const std::string &name = "", - AutodiffMode autodiff_mode = AutodiffMode::kNone) { - // Expr::set_allow_store(true); - auto func = std::make_unique(*this, body, name, autodiff_mode); - // Expr::set_allow_store(false); - kernels.emplace_back(std::move(func)); - return *kernels.back(); - } - Kernel &kernel(const std::function &body, const std::string &name = "", AutodiffMode autodiff_mode = AutodiffMode::kNone) { diff --git a/taichi/transforms/constant_fold.cpp b/taichi/transforms/constant_fold.cpp index 1ec03c5b6c782..6fe50f2d33911 100644 --- a/taichi/transforms/constant_fold.cpp +++ b/taichi/transforms/constant_fold.cpp @@ -32,7 +32,7 @@ class ConstantFold : public BasicStmtVisitor { return it->second.get(); auto kernel_name = fmt::format("jit_evaluator_{}", cache.size()); - auto func = [&id, this]() { + auto func = [&id](Kernel *kernel) { auto lhstmt = Stmt::make(/*arg_id=*/0, id.lhs, /*is_ptr=*/false); auto rhstmt = @@ -48,12 +48,14 @@ class ConstantFold : public BasicStmtVisitor { oper->cast()->cast_type = id.rhs; } } + auto &ast_builder = kernel->context->builder(); auto ret = Stmt::make(oper.get()); - program->current_ast_builder()->insert(std::move(lhstmt)); - if (id.is_binary) - program->current_ast_builder()->insert(std::move(rhstmt)); - program->current_ast_builder()->insert(std::move(oper)); - program->current_ast_builder()->insert(std::move(ret)); + ast_builder.insert(std::move(lhstmt)); + if (id.is_binary) { + ast_builder.insert(std::move(rhstmt)); + } + ast_builder.insert(std::move(oper)); + ast_builder.insert(std::move(ret)); }; auto ker = std::make_unique(*program, func, kernel_name);