Skip to content

Commit

Permalink
[llvm] [refactor] Split LLVMCompiledData of kernels and tasks (#6019)
Browse files Browse the repository at this point in the history
Related issue = #5511 
The compiler first compiles every offloaded task to a `LLVMCompiledTask`
which contains an LLVM module, the name of the offloaded task function
inside the module, and some extra information for linking. Then, the
compiler links the modules in the `LLVMCompiledTask`s, the runtime
modules and the struct modules used in the kernel together, and creates
a `LLVMCompiledKernel` that contains the linked LLVM module and the names
of the offloaded tasks.

Both `LLVMCompiledTask` and `LLVMCompiledKernel` need to store the
generated LLVM module and the names of the functions of the offloaded
tasks inside the module. `LLVMCompiledTask` also stores additional
information that will be used when linking, such as which SNodeTrees are
used and the sizes of the TLS buffers used in parallel struct-fors.
<!--
Thank you for your contribution!

If it is your first time contributing to Taichi, please read our
Contributor Guidelines:
  https://docs.taichi-lang.org/docs/contributor_guide

- Please always prepend your PR title with tags such as [CUDA], [Lang],
[Doc], [Example]. For a complete list of valid PR tags, please check out
https://github.com/taichi-dev/taichi/blob/master/misc/prtags.json.
- Use upper-case tags (e.g., [Metal]) for PRs that change public APIs.
Otherwise, please use lower-case tags (e.g., [metal]).
- More details:
https://docs.taichi-lang.org/docs/contributor_guide#pr-title-format-and-tags

- Please fill in the issue number that this PR relates to.
- If your PR fixes the issue **completely**, use the `close` or `fixes`
prefix so that GitHub automatically closes the issue when the PR is
merged. For example,
    Related issue = close #2345
- If the PR does not belong to any existing issue, feel free to leave it
blank.
-->

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lin-hitonami and pre-commit-ci[bot] authored Sep 13, 2022
1 parent 58d177f commit eee2e89
Show file tree
Hide file tree
Showing 23 changed files with 80 additions and 65 deletions.
18 changes: 9 additions & 9 deletions taichi/codegen/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ std::unique_ptr<KernelCodeGen> KernelCodeGen::create(Arch arch,
}
#ifdef TI_WITH_LLVM

std::optional<LLVMCompiledData>
std::optional<LLVMCompiledKernel>
KernelCodeGen::maybe_read_compilation_from_cache(
const std::string &kernel_key) {
TI_AUTO_PROF;
Expand All @@ -79,13 +79,13 @@ KernelCodeGen::maybe_read_compilation_from_cache(
return {std::move(cache_data.compiled_data)};
}

void KernelCodeGen::cache_module(const std::string &kernel_key,
const LLVMCompiledData &data) {
void KernelCodeGen::cache_kernel(const std::string &kernel_key,
const LLVMCompiledKernel &data) {
get_llvm_program(prog)->cache_kernel(kernel_key, data,
infer_launch_args(kernel));
}

LLVMCompiledData KernelCodeGen::compile_kernel_to_module() {
LLVMCompiledKernel KernelCodeGen::compile_kernel_to_module() {
auto *llvm_prog = get_llvm_program(prog);
auto *tlctx = llvm_prog->get_llvm_context(kernel->arch);
auto &config = prog->config;
Expand All @@ -97,7 +97,7 @@ LLVMCompiledData KernelCodeGen::compile_kernel_to_module() {
if (res) {
TI_DEBUG("Create kernel '{}' from cache (key='{}')", kernel->get_name(),
kernel_key);
cache_module(kernel_key, *res);
cache_kernel(kernel_key, *res);
return std::move(*res);
}
}
Expand All @@ -110,7 +110,7 @@ LLVMCompiledData KernelCodeGen::compile_kernel_to_module() {
TI_ASSERT(block);

auto &offloads = block->statements;
std::vector<std::unique_ptr<LLVMCompiledData>> data(offloads.size());
std::vector<std::unique_ptr<LLVMCompiledTask>> data(offloads.size());
using TaskFunc = int32 (*)(void *);
std::vector<TaskFunc> task_funcs(offloads.size());
for (int i = 0; i < offloads.size(); i++) {
Expand All @@ -120,7 +120,7 @@ LLVMCompiledData KernelCodeGen::compile_kernel_to_module() {
irpass::analysis::clone(offloads[i].get(), offloads[i]->get_kernel());
irpass::re_id(offload.get());
auto new_data = this->compile_task(nullptr, offload->as<OffloadedStmt>());
data[i] = std::make_unique<LLVMCompiledData>(std::move(new_data));
data[i] = std::make_unique<LLVMCompiledTask>(std::move(new_data));
};
if (kernel->is_evaluator) {
compile_func();
Expand All @@ -135,7 +135,7 @@ LLVMCompiledData KernelCodeGen::compile_kernel_to_module() {

if (!kernel->is_evaluator) {
TI_DEBUG("Cache kernel '{}' (key='{}')", kernel->get_name(), kernel_key);
cache_module(kernel_key, linked);
cache_kernel(kernel_key, linked);
}
return linked;
}
Expand All @@ -147,7 +147,7 @@ ModuleToFunctionConverter::ModuleToFunctionConverter(
}

FunctionType ModuleToFunctionConverter::convert(const Kernel *kernel,
LLVMCompiledData data) const {
LLVMCompiledKernel data) const {
return convert(kernel->name, infer_launch_args(kernel), std::move(data));
}

Expand Down
14 changes: 7 additions & 7 deletions taichi/codegen/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ class KernelCodeGen {
}

#ifdef TI_WITH_LLVM
virtual LLVMCompiledData compile_kernel_to_module();
virtual LLVMCompiledKernel compile_kernel_to_module();

virtual LLVMCompiledData compile_task(
virtual LLVMCompiledTask compile_task(
std::unique_ptr<llvm::Module> &&module = nullptr,
OffloadedStmt *stmt = nullptr){TI_NOT_IMPLEMENTED}

std::optional<LLVMCompiledData> maybe_read_compilation_from_cache(
std::optional<LLVMCompiledKernel> maybe_read_compilation_from_cache(
const std::string &kernel_key);
void cache_module(const std::string &kernel_key,
const LLVMCompiledData &data);
void cache_kernel(const std::string &kernel_key,
const LLVMCompiledKernel &data);
#endif
};

Expand All @@ -57,10 +57,10 @@ class ModuleToFunctionConverter {

virtual FunctionType convert(const std::string &kernel_name,
const std::vector<LlvmLaunchArgInfo> &args,
LLVMCompiledData data) const = 0;
LLVMCompiledKernel data) const = 0;

virtual FunctionType convert(const Kernel *kernel,
LLVMCompiledData data) const;
LLVMCompiledKernel data) const;

protected:
TaichiLLVMContext *tlctx_{nullptr};
Expand Down
8 changes: 3 additions & 5 deletions taichi/codegen/cpu/codegen_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ std::unique_ptr<TaskCodeGenLLVM> KernelCodeGenCPU::make_codegen_llvm(
FunctionType CPUModuleToFunctionConverter::convert(
const std::string &kernel_name,
const std::vector<LlvmLaunchArgInfo> &args,
LLVMCompiledData data) const {
LLVMCompiledKernel data) const {
TI_AUTO_PROF;
auto jit_module = tlctx_->create_jit_module(std::move(data.module));
using TaskFunc = int32 (*)(void *);
Expand Down Expand Up @@ -271,7 +271,7 @@ FunctionType CPUModuleToFunctionConverter::convert(
};
}

LLVMCompiledData KernelCodeGenCPU::compile_task(
LLVMCompiledTask KernelCodeGenCPU::compile_task(
std::unique_ptr<llvm::Module> &&module,
OffloadedStmt *stmt) {
TaskCodeGenCPU gen(kernel, stmt);
Expand All @@ -284,10 +284,8 @@ FunctionType KernelCodeGenCPU::compile_to_function() {
auto *llvm_prog = get_llvm_program(prog);
auto *tlctx = llvm_prog->get_llvm_context(kernel->arch);

LLVMCompiledData data = compile_kernel_to_module();

CPUModuleToFunctionConverter converter(
tlctx, get_llvm_program(prog)->get_runtime_executor());
return converter.convert(kernel, std::move(data));
return converter.convert(kernel, compile_kernel_to_module());
}
TLANG_NAMESPACE_END
4 changes: 2 additions & 2 deletions taichi/codegen/cpu/codegen_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class KernelCodeGenCPU : public KernelCodeGen {
bool supports_offline_cache() const override {
return true;
}
LLVMCompiledData compile_task(
LLVMCompiledTask compile_task(
std::unique_ptr<llvm::Module> &&module = nullptr,
OffloadedStmt *stmt = nullptr) override;

Expand All @@ -45,7 +45,7 @@ class CPUModuleToFunctionConverter : public ModuleToFunctionConverter {

FunctionType convert(const std::string &kernel_name,
const std::vector<LlvmLaunchArgInfo> &args,
LLVMCompiledData data) const override;
LLVMCompiledKernel data) const override;
};

#endif
Expand Down
7 changes: 3 additions & 4 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ std::unique_ptr<TaskCodeGenLLVM> KernelCodeGenCUDA::make_codegen_llvm(
}
#endif // TI_WITH_LLVM

LLVMCompiledData KernelCodeGenCUDA::compile_task(
LLVMCompiledTask KernelCodeGenCUDA::compile_task(
std::unique_ptr<llvm::Module> &&module,
OffloadedStmt *stmt) {
TaskCodeGenCUDA gen(kernel, stmt);
Expand All @@ -706,17 +706,16 @@ FunctionType KernelCodeGenCUDA::compile_to_function() {
auto *llvm_prog = get_llvm_program(prog);
auto *tlctx = llvm_prog->get_llvm_context(kernel->arch);

LLVMCompiledData data = compile_kernel_to_module();
CUDAModuleToFunctionConverter converter{tlctx,
llvm_prog->get_runtime_executor()};

return converter.convert(this->kernel, std::move(data));
return converter.convert(this->kernel, compile_kernel_to_module());
}

FunctionType CUDAModuleToFunctionConverter::convert(
const std::string &kernel_name,
const std::vector<LlvmLaunchArgInfo> &args,
LLVMCompiledData data) const {
LLVMCompiledKernel data) const {
auto &mod = data.module;
auto &tasks = data.tasks;
#ifdef TI_WITH_CUDA
Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/cuda/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class KernelCodeGenCUDA : public KernelCodeGen {
#ifdef TI_WITH_LLVM
static std::unique_ptr<TaskCodeGenLLVM> make_codegen_llvm(Kernel *kernel,
IRNode *ir);
LLVMCompiledData compile_task(
LLVMCompiledTask compile_task(
std::unique_ptr<llvm::Module> &&module = nullptr,
OffloadedStmt *stmt = nullptr) override;
#endif // TI_WITH_LLVM
Expand All @@ -39,7 +39,7 @@ class CUDAModuleToFunctionConverter : public ModuleToFunctionConverter {

FunctionType convert(const std::string &kernel_name,
const std::vector<LlvmLaunchArgInfo> &args,
LLVMCompiledData data) const override;
LLVMCompiledKernel data) const override;
};

TLANG_NAMESPACE_END
8 changes: 6 additions & 2 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2647,7 +2647,7 @@ void TaskCodeGenLLVM::emit_to_module() {
ir->accept(this);
}

LLVMCompiledData TaskCodeGenLLVM::run_compilation() {
LLVMCompiledTask TaskCodeGenLLVM::run_compilation() {
// Final lowering

auto config = kernel->program->config;
Expand Down Expand Up @@ -2732,11 +2732,15 @@ void TaskCodeGenLLVM::visit(FuncCallStmt *stmt) {
}
}

LLVMCompiledData LLVMCompiledData::clone() const {
LLVMCompiledTask LLVMCompiledTask::clone() const {
return {tasks, llvm::CloneModule(*module), used_tree_ids,
struct_for_tls_sizes};
}

LLVMCompiledKernel LLVMCompiledKernel::clone() const {
return {tasks, llvm::CloneModule(*module)};
}

TLANG_NAMESPACE_END

#endif // #ifdef TI_WITH_LLVM
4 changes: 2 additions & 2 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
*
* After this call, `module` and `tasks` will be moved.
*
* @return LLVMCompiledData
* @return LLVMCompiledTask
*/
virtual LLVMCompiledData run_compilation();
virtual LLVMCompiledTask run_compilation();
// For debugging only
virtual llvm::Value *create_print(std::string tag,
DataType dt,
Expand Down
26 changes: 20 additions & 6 deletions taichi/codegen/llvm/llvm_compiled_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ class OffloadedTask {
TI_IO_DEF(name, block_dim, grid_dim);
};

struct LLVMCompiledData {
struct LLVMCompiledTask {
std::vector<OffloadedTask> tasks;
std::unique_ptr<llvm::Module> module{nullptr};
std::unordered_set<int> used_tree_ids;
std::unordered_set<int> struct_for_tls_sizes;
LLVMCompiledData() = default;
LLVMCompiledData(LLVMCompiledData &&) = default;
LLVMCompiledData &operator=(LLVMCompiledData &&) = default;
LLVMCompiledData(std::vector<OffloadedTask> tasks,
LLVMCompiledTask() = default;
LLVMCompiledTask(LLVMCompiledTask &&) = default;
LLVMCompiledTask &operator=(LLVMCompiledTask &&) = default;
LLVMCompiledTask(std::vector<OffloadedTask> tasks,
std::unique_ptr<llvm::Module> module,
std::unordered_set<int> used_tree_ids,
std::unordered_set<int> struct_for_tls_sizes)
Expand All @@ -38,7 +38,21 @@ struct LLVMCompiledData {
used_tree_ids(std::move(used_tree_ids)),
struct_for_tls_sizes(std::move(struct_for_tls_sizes)) {
}
LLVMCompiledData clone() const;
LLVMCompiledTask clone() const;
TI_IO_DEF(tasks);
};

// The result of compiling a whole kernel: the linked LLVM module together
// with the offloaded tasks it contains. Produced by
// TaichiLLVMContext::link_compiled_tasks() from a set of LLVMCompiledTask
// objects (see llvm_context.cpp in this commit).
struct LLVMCompiledKernel {
// Metadata (name, block_dim, grid_dim — per OffloadedTask's TI_IO_DEF) of
// each offloaded task whose function lives inside `module`.
std::vector<OffloadedTask> tasks;
// The linked LLVM module containing all of the kernel's task functions.
std::unique_ptr<llvm::Module> module{nullptr};
// Move-only: `module` is a unique_ptr, so copy operations are implicitly
// deleted; use clone() when a deep copy is needed.
LLVMCompiledKernel() = default;
LLVMCompiledKernel(LLVMCompiledKernel &&) = default;
LLVMCompiledKernel &operator=(LLVMCompiledKernel &&) = default;
LLVMCompiledKernel(std::vector<OffloadedTask> tasks,
                 std::unique_ptr<llvm::Module> module)
    : tasks(std::move(tasks)), module(std::move(module)) {
}
// Deep copy: clones the LLVM module via llvm::CloneModule (defined in
// codegen_llvm.cpp).
LLVMCompiledKernel clone() const;
// Only `tasks` participates in serialization; the LLVM module itself is
// intentionally not part of TI_IO.
TI_IO_DEF(tasks);
};

Expand Down
12 changes: 6 additions & 6 deletions taichi/codegen/wasm/codegen_wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ class TaskCodeGenWASM : public TaskCodeGenLLVM {
TI_ASSERT(!llvm::verifyFunction(*func, &llvm::errs()));
}

LLVMCompiledData run_compilation() override {
LLVMCompiledTask run_compilation() override {
// lower kernel
if (!kernel->lowered()) {
kernel->lower();
Expand All @@ -235,7 +235,7 @@ class TaskCodeGenWASM : public TaskCodeGenLLVM {
}
return func_name == offloaded_task_name;
});
LLVMCompiledData res;
LLVMCompiledTask res;
res.tasks.emplace_back(offloaded_task_name);
res.module = std::move(this->module);
return res;
Expand All @@ -255,7 +255,7 @@ FunctionType KernelCodeGenWASM::compile_to_function() {
};
}

LLVMCompiledData KernelCodeGenWASM::compile_task(
LLVMCompiledTask KernelCodeGenWASM::compile_task(
std::unique_ptr<llvm::Module> &&module,
OffloadedStmt *stmt) {
kernel->offload_to_executable(ir);
Expand All @@ -281,14 +281,14 @@ LLVMCompiledData KernelCodeGenWASM::compile_task(
return {name_list, std::move(gen->module), {}, {}};
}

LLVMCompiledData KernelCodeGenWASM::compile_kernel_to_module() {
LLVMCompiledKernel KernelCodeGenWASM::compile_kernel_to_module() {
auto *tlctx = get_llvm_program(prog)->get_llvm_context(kernel->arch);
if (!kernel->lowered()) {
kernel->lower(/*to_executable=*/false);
}
auto res = compile_task();
std::vector<std::unique_ptr<LLVMCompiledData>> data;
data.push_back(std::make_unique<LLVMCompiledData>(std::move(res)));
std::vector<std::unique_ptr<LLVMCompiledTask>> data;
data.push_back(std::make_unique<LLVMCompiledTask>(std::move(res)));
return tlctx->link_compiled_tasks(std::move(data));
}

Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/wasm/codegen_wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ class KernelCodeGenWASM : public KernelCodeGen {
FunctionType compile_to_function() override;

#ifdef TI_WITH_LLVM
LLVMCompiledData compile_task(
LLVMCompiledTask compile_task(
std::unique_ptr<llvm::Module> &&module = nullptr,
OffloadedStmt *stmt = nullptr) override; // AOT Module Gen

LLVMCompiledData compile_kernel_to_module() override;
LLVMCompiledKernel compile_kernel_to_module() override;
#endif
};

Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cpu/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace taichi {
namespace lang {
namespace cpu {

LLVMCompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
LLVMCompiledKernel AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
auto cgen = KernelCodeGenCPU(kernel);
return cgen.compile_kernel_to_module();
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cpu/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
}

private:
LLVMCompiledData compile_kernel(Kernel *kernel) override;
LLVMCompiledKernel compile_kernel(Kernel *kernel) override;
};

} // namespace cpu
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cuda/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace taichi {
namespace lang {
namespace cuda {

LLVMCompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
LLVMCompiledKernel AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
auto cgen = KernelCodeGenCUDA(kernel);
return cgen.compile_kernel_to_module();
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cuda/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
}

private:
LLVMCompiledData compile_kernel(Kernel *kernel) override;
LLVMCompiledKernel compile_kernel(Kernel *kernel) override;
};

} // namespace cuda
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/llvm/llvm_aot_module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class LlvmAotModuleBuilder : public AotModuleBuilder {

protected:
void add_per_backend(const std::string &identifier, Kernel *kernel) override;
virtual LLVMCompiledData compile_kernel(Kernel *kernel) = 0;
virtual LLVMCompiledKernel compile_kernel(Kernel *kernel) = 0;

void add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
Expand Down
6 changes: 3 additions & 3 deletions taichi/runtime/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -889,9 +889,9 @@ TaichiLLVMContext::ThreadLocalData::~ThreadLocalData() {
thread_safe_llvm_context.reset();
}

LLVMCompiledData TaichiLLVMContext::link_compiled_tasks(
std::vector<std::unique_ptr<LLVMCompiledData>> data_list) {
LLVMCompiledData linked;
LLVMCompiledKernel TaichiLLVMContext::link_compiled_tasks(
std::vector<std::unique_ptr<LLVMCompiledTask>> data_list) {
LLVMCompiledKernel linked;
std::unordered_set<int> used_tree_ids;
std::unordered_set<int> tls_sizes;
std::unordered_set<std::string> offloaded_names;
Expand Down
Loading

0 comments on commit eee2e89

Please sign in to comment.