Skip to content

Commit

Permalink
[refactor] Remove dependencies on Program::current_ast_builder() in C…
Browse files Browse the repository at this point in the history
…++ side (#7044)

Issue: #7002
  • Loading branch information
PGZXB authored Jan 5, 2023
1 parent a034184 commit 8066f43
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 24 deletions.
16 changes: 8 additions & 8 deletions taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,17 +359,17 @@ Arch Program::get_accessor_arch() {
Kernel &Program::get_snode_reader(SNode *snode) {
TI_ASSERT(snode->type == SNodeType::place);
auto kernel_name = fmt::format("snode_reader_{}", snode->id);
auto &ker = kernel([snode, this] {
auto &ker = kernel([snode, this](Kernel *kernel) {
ExprGroup indices;
for (int i = 0; i < snode->num_active_indices; i++) {
auto argload_expr = Expr::make<ArgLoadExpression>(i, PrimitiveType::i32);
argload_expr->type_check(&this->this_thread_config());
indices.push_back(std::move(argload_expr));
}
ASTBuilder *builder = this->current_ast_builder();
ASTBuilder &builder = kernel->context->builder();
auto ret = Stmt::make<FrontendReturnStmt>(ExprGroup(
builder->expr_subscript(Expr(snode_to_fields_.at(snode)), indices)));
this->current_ast_builder()->insert(std::move(ret));
builder.expr_subscript(Expr(snode_to_fields_.at(snode)), indices)));
builder.insert(std::move(ret));
});
ker.set_arch(get_accessor_arch());
ker.name = kernel_name;
Expand All @@ -383,17 +383,17 @@ Kernel &Program::get_snode_reader(SNode *snode) {
Kernel &Program::get_snode_writer(SNode *snode) {
TI_ASSERT(snode->type == SNodeType::place);
auto kernel_name = fmt::format("snode_writer_{}", snode->id);
auto &ker = kernel([snode, this] {
auto &ker = kernel([snode, this](Kernel *kernel) {
ExprGroup indices;
for (int i = 0; i < snode->num_active_indices; i++) {
auto argload_expr = Expr::make<ArgLoadExpression>(i, PrimitiveType::i32);
argload_expr->type_check(&this->this_thread_config());
indices.push_back(std::move(argload_expr));
}
ASTBuilder *builder = current_ast_builder();
ASTBuilder &builder = kernel->context->builder();
auto expr =
builder->expr_subscript(Expr(snode_to_fields_.at(snode)), indices);
this->current_ast_builder()->insert_assignment(
builder.expr_subscript(Expr(snode_to_fields_.at(snode)), indices);
builder.insert_assignment(
expr,
Expr::make<ArgLoadExpression>(snode->num_active_indices,
snode->dt->get_compute_type()),
Expand Down
10 changes: 0 additions & 10 deletions taichi/program/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,6 @@ class TI_DLL_EXPORT Program {

void visualize_layout(const std::string &fn);

Kernel &kernel(const std::function<void()> &body,
const std::string &name = "",
AutodiffMode autodiff_mode = AutodiffMode::kNone) {
// Expr::set_allow_store(true);
auto func = std::make_unique<Kernel>(*this, body, name, autodiff_mode);
// Expr::set_allow_store(false);
kernels.emplace_back(std::move(func));
return *kernels.back();
}

Kernel &kernel(const std::function<void(Kernel *)> &body,
const std::string &name = "",
AutodiffMode autodiff_mode = AutodiffMode::kNone) {
Expand Down
14 changes: 8 additions & 6 deletions taichi/transforms/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class ConstantFold : public BasicStmtVisitor {
return it->second.get();

auto kernel_name = fmt::format("jit_evaluator_{}", cache.size());
auto func = [&id, this]() {
auto func = [&id](Kernel *kernel) {
auto lhstmt =
Stmt::make<ArgLoadStmt>(/*arg_id=*/0, id.lhs, /*is_ptr=*/false);
auto rhstmt =
Expand All @@ -48,12 +48,14 @@ class ConstantFold : public BasicStmtVisitor {
oper->cast<UnaryOpStmt>()->cast_type = id.rhs;
}
}
auto &ast_builder = kernel->context->builder();
auto ret = Stmt::make<ReturnStmt>(oper.get());
program->current_ast_builder()->insert(std::move(lhstmt));
if (id.is_binary)
program->current_ast_builder()->insert(std::move(rhstmt));
program->current_ast_builder()->insert(std::move(oper));
program->current_ast_builder()->insert(std::move(ret));
ast_builder.insert(std::move(lhstmt));
if (id.is_binary) {
ast_builder.insert(std::move(rhstmt));
}
ast_builder.insert(std::move(oper));
ast_builder.insert(std::move(ret));
};

auto ker = std::make_unique<Kernel>(*program, func, kernel_name);
Expand Down

0 comments on commit 8066f43

Please sign in to comment.