Skip to content

Commit

Permalink
[refactor] Split constructing and compilation of lang::Function (#7209)
Browse files Browse the repository at this point in the history
Issue: #7002 

### Brief Summary
Removed dependencies on `Program::this_thread_config()` in
`lang::Function` compilation (AST->IR part)
* Push off the compilation of `lang::Function`: Introduce the
`irpass::compile_called_function(IRNode *root, const CompileConfig
&config)`, which compiles the AST/IR of `Function`s called in `root` to
the final IR.
  • Loading branch information
PGZXB authored Jan 19, 2023
1 parent b13b28c commit f760fbb
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 9 deletions.
2 changes: 2 additions & 0 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ void compile_function(IRNode *ir,
void ast_to_ir(const CompileConfig &config,
Kernel &kernel,
bool to_executable = true);

void compile_taichi_functions(IRNode *ir, const CompileConfig &compile_config);
} // namespace irpass

} // namespace taichi::lang
10 changes: 2 additions & 8 deletions taichi/program/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Function::Function(Program *program, const FunctionKey &func_key)
void Function::set_function_body(const std::function<void()> &func) {
context = std::make_unique<FrontendContext>(program->compile_config().arch);
ir = context->get_root();
ir_type_ = IRType::AST;

func();

Expand All @@ -21,18 +22,11 @@ void Function::set_function_body(const std::function<void()> &func) {
gen_offline_cache_key(program, ir.get(), &oss);
ast_serialization_data_ = oss.str();
}
irpass::compile_function(ir.get(), program->compile_config(), this,
/*autodiff_mode=*/AutodiffMode::kNone,
/*verbose=*/program->compile_config().print_ir,
/*start_from_ast=*/true);
}

void Function::set_function_body(std::unique_ptr<IRNode> func_body) {
ir = std::move(func_body);
irpass::compile_function(ir.get(), program->compile_config(), this,
/*autodiff_mode=*/AutodiffMode::kNone,
/*verbose=*/program->compile_config().print_ir,
/*start_from_ast=*/false);
ir_type_ = IRType::InitialIR;
}

std::string Function::get_name() const {
Expand Down
13 changes: 12 additions & 1 deletion taichi/program/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ class Program;

class Function : public Callable {
public:
enum class IRType { None, AST, InitialIR, OptimizedIR };

FunctionKey func_key;

Function(Program *program, const FunctionKey &func_key);
Expand All @@ -22,11 +24,20 @@ class Function : public Callable {

[[nodiscard]] std::string get_name() const override;

std::optional<std::string> &try_get_ast_serialization_data() {
const std::optional<std::string> &try_get_ast_serialization_data() const {
return ast_serialization_data_;
}

void set_ir_type(IRType type) {
ir_type_ = type;
}

IRType ir_type() const {
return ir_type_;
}

private:
IRType ir_type_{IRType::None};
std::optional<std::string> ast_serialization_data_; // For generating AST-Key
};

Expand Down
50 changes: 50 additions & 0 deletions taichi/transforms/compile_taichi_functions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/ir/statements.h"
#include "taichi/program/function.h"
#include "taichi/program/compile_config.h"

namespace taichi::lang {

class CompileTaichiFunctions : public BasicStmtVisitor {
public:
using BasicStmtVisitor::visit;

explicit CompileTaichiFunctions(const CompileConfig &compile_config)
: compile_config_(compile_config) {
}

void visit(FuncCallStmt *stmt) override {
using IRType = Function::IRType;
auto *func = stmt->func;
const auto ir_type = func->ir_type();
if (ir_type != IRType::OptimizedIR) {
TI_ASSERT(ir_type == IRType::AST || ir_type == IRType::InitialIR);
func->set_ir_type(IRType::OptimizedIR);
irpass::compile_function(func->ir.get(), compile_config_, func,
/*autodiff_mode=*/AutodiffMode::kNone,
/*verbose=*/compile_config_.print_ir,
/*start_from_ast=*/ir_type == IRType::AST);
func->ir->accept(this);
}
}

static void run(IRNode *ir, const CompileConfig &compile_config) {
CompileTaichiFunctions ctf{compile_config};
ir->accept(&ctf);
}

private:
const CompileConfig &compile_config_;
};

namespace irpass {

void compile_taichi_functions(IRNode *ir, const CompileConfig &compile_config) {
TI_AUTO_PROF;
CompileTaichiFunctions::run(ir, compile_config);
}

} // namespace irpass

} // namespace taichi::lang
2 changes: 2 additions & 0 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ void compile_to_offloads(IRNode *ir,
print("Lowered");
}

irpass::compile_taichi_functions(ir, config);

irpass::eliminate_immutable_local_vars(ir);
print("Immutable local vars eliminated");

Expand Down

0 comments on commit f760fbb

Please sign in to comment.