From e3eb6be73ee729ea05a207f97a71a8195f5d8771 Mon Sep 17 00:00:00 2001 From: Lin Jiang Date: Thu, 8 Sep 2022 17:53:04 +0800 Subject: [PATCH] [llvm] [refactor] Remove the use of vector with size=1 (#6002) Related issue = #5511 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- taichi/codegen/codegen.cpp | 38 +++-- taichi/codegen/codegen.h | 16 +-- taichi/codegen/cpu/codegen_cpu.cpp | 20 ++- taichi/codegen/cpu/codegen_cpu.h | 2 +- taichi/codegen/cuda/codegen_cuda.cpp | 8 +- taichi/codegen/cuda/codegen_cuda.h | 2 +- taichi/codegen/llvm/llvm_compiled_data.h | 1 + taichi/codegen/wasm/codegen_wasm.cpp | 9 +- taichi/codegen/wasm/codegen_wasm.h | 2 +- .../runtime/cpu/aot_module_builder_impl.cpp | 2 +- taichi/runtime/cpu/aot_module_loader_impl.cpp | 2 +- .../runtime/cuda/aot_module_builder_impl.cpp | 2 +- .../runtime/cuda/aot_module_loader_impl.cpp | 2 +- .../runtime/llvm/llvm_aot_module_builder.cpp | 2 +- taichi/runtime/llvm/llvm_context.cpp | 8 +- taichi/runtime/llvm/llvm_context.h | 2 +- taichi/runtime/llvm/llvm_offline_cache.cpp | 134 +++++++----------- taichi/runtime/llvm/llvm_offline_cache.h | 14 +- .../program_impls/llvm/llvm_program.cpp | 13 +- .../runtime/program_impls/llvm/llvm_program.h | 2 +- .../runtime/wasm/aot_module_builder_impl.cpp | 4 +- tests/cpp/codegen/refine_coordinates_test.cpp | 2 +- tests/cpp/llvm/llvm_offline_cache_test.cpp | 18 ++- 23 files changed, 128 insertions(+), 177 deletions(-) diff --git a/taichi/codegen/codegen.cpp b/taichi/codegen/codegen.cpp index 0c945d9a36bb9..e9956154b9ad9 100644 --- a/taichi/codegen/codegen.cpp +++ b/taichi/codegen/codegen.cpp @@ -57,15 +57,15 @@ std::unique_ptr KernelCodeGen::create(Arch arch, } #ifdef TI_WITH_LLVM -bool KernelCodeGen::maybe_read_compilation_from_cache( - const std::string &kernel_key, - std::vector &data) { +std::optional +KernelCodeGen::maybe_read_compilation_from_cache( + const std::string &kernel_key) { TI_AUTO_PROF; const auto &config = prog->config; auto *llvm_prog = get_llvm_program(prog); const auto &reader = llvm_prog->get_cache_reader(); if (!reader) { - return false; + return std::nullopt; } LlvmOfflineCache::KernelCacheData cache_data; @@ -73,20 +73,19 @@ bool KernelCodeGen::maybe_read_compilation_from_cache( auto &llvm_ctx = *tlctx->get_this_thread_context(); if (!reader->get_kernel_cache(cache_data, kernel_key, llvm_ctx)) { - return false; + return std::nullopt; } - data.swap(cache_data.compiled_data_list); kernel->mark_as_from_cache(); - return true; + return {std::move(cache_data.compiled_data)}; } void KernelCodeGen::cache_module(const std::string &kernel_key, - const std::vector &data) { + const LLVMCompiledData &data) { get_llvm_program(prog)->cache_kernel(kernel_key, data, infer_launch_args(kernel)); } -std::vector KernelCodeGen::compile_kernel_to_module() { +LLVMCompiledData KernelCodeGen::compile_kernel_to_module() { auto *llvm_prog = get_llvm_program(prog); auto *tlctx = llvm_prog->get_llvm_context(kernel->arch); auto &config = prog->config; @@ -94,14 +93,12 @@ std::vector KernelCodeGen::compile_kernel_to_module() { kernel->set_kernel_key_for_cache(kernel_key); if (config.offline_cache && this->supports_offline_cache() && !kernel->is_evaluator) { - std::vector res; - const bool ok = maybe_read_compilation_from_cache(kernel_key, res); - if (ok) { + auto res = maybe_read_compilation_from_cache(kernel_key); + if (res) { TI_DEBUG("Create kernel '{}' from cache (key='{}')", kernel->get_name(), kernel_key); - cache_module(kernel_key, res); - TI_ASSERT(res.size() == 1); - return res; + cache_module(kernel_key, *res); + return std::move(*res); } } if (!kernel->lowered()) { @@ -135,14 +132,12 @@ std::vector KernelCodeGen::compile_kernel_to_module() { worker.flush(); } auto linked = tlctx->link_compiled_tasks(std::move(data)); - std::vector linked_data; - linked_data.push_back(std::move(*linked)); if (!kernel->is_evaluator) { TI_DEBUG("Cache kernel '{}' (key='{}')", kernel->get_name(), kernel_key); - cache_module(kernel_key, linked_data); + cache_module(kernel_key, linked); } - return linked_data; + return linked; } ModuleToFunctionConverter::ModuleToFunctionConverter( @@ -151,9 +146,8 @@ ModuleToFunctionConverter::ModuleToFunctionConverter( : tlctx_(tlctx), executor_(executor) { } -FunctionType ModuleToFunctionConverter::convert( - const Kernel *kernel, - std::vector &&data) const { +FunctionType ModuleToFunctionConverter::convert(const Kernel *kernel, + LLVMCompiledData data) const { return convert(kernel->name, infer_launch_args(kernel), std::move(data)); } diff --git a/taichi/codegen/codegen.h b/taichi/codegen/codegen.h index 53233ac45d734..bff0eb3729aa4 100644 --- a/taichi/codegen/codegen.h +++ b/taichi/codegen/codegen.h @@ -33,18 +33,16 @@ class KernelCodeGen { } #ifdef TI_WITH_LLVM - virtual std::vector compile_kernel_to_module(); + virtual LLVMCompiledData compile_kernel_to_module(); virtual LLVMCompiledData compile_task( std::unique_ptr &&module = nullptr, - OffloadedStmt *stmt = nullptr) { - TI_NOT_IMPLEMENTED - } + OffloadedStmt *stmt = nullptr){TI_NOT_IMPLEMENTED} - bool maybe_read_compilation_from_cache(const std::string &kernel_key, - std::vector &data); + std::optional maybe_read_compilation_from_cache( + const std::string &kernel_key); void cache_module(const std::string &kernel_key, - const std::vector &data); + const LLVMCompiledData &data); #endif }; @@ -59,10 +57,10 @@ class ModuleToFunctionConverter { virtual FunctionType convert(const std::string &kernel_name, const std::vector &args, - std::vector &&data) const = 0; + LLVMCompiledData data) const = 0; virtual FunctionType convert(const Kernel *kernel, - std::vector &&data) const; + LLVMCompiledData data) const; protected: TaichiLLVMContext *tlctx_{nullptr}; diff --git a/taichi/codegen/cpu/codegen_cpu.cpp b/taichi/codegen/cpu/codegen_cpu.cpp index ca4742bf51e6a..986eb43706b14 100644 --- a/taichi/codegen/cpu/codegen_cpu.cpp +++ b/taichi/codegen/cpu/codegen_cpu.cpp @@ -234,19 +234,17 @@ std::unique_ptr KernelCodeGenCPU::make_codegen_llvm( FunctionType CPUModuleToFunctionConverter::convert( const std::string &kernel_name, const std::vector &args, - std::vector &&data) const { + LLVMCompiledData data) const { TI_AUTO_PROF; - auto jit_module = tlctx_->create_jit_module(std::move(data.back().module)); + auto jit_module = tlctx_->create_jit_module(std::move(data.module)); using TaskFunc = int32 (*)(void *); std::vector task_funcs; - task_funcs.reserve(data.size()); - for (auto &datum : data) { - for (auto &task : datum.tasks) { - auto *func_ptr = jit_module->lookup_function(task.name); - TI_ASSERT_INFO(func_ptr, "Offloaded datum function {} not found", - task.name); - task_funcs.push_back((TaskFunc)(func_ptr)); - } + task_funcs.reserve(data.tasks.size()); + for (auto &task : data.tasks) { + auto *func_ptr = jit_module->lookup_function(task.name); + TI_ASSERT_INFO(func_ptr, "Offloaded datum function {} not found", + task.name); + task_funcs.push_back((TaskFunc)(func_ptr)); } // Do NOT capture `this`... return [executor = this->executor_, args, kernel_name, @@ -286,7 +284,7 @@ FunctionType KernelCodeGenCPU::compile_to_function() { auto *llvm_prog = get_llvm_program(prog); auto *tlctx = llvm_prog->get_llvm_context(kernel->arch); - std::vector data = compile_kernel_to_module(); + LLVMCompiledData data = compile_kernel_to_module(); CPUModuleToFunctionConverter converter( tlctx, get_llvm_program(prog)->get_runtime_executor()); diff --git a/taichi/codegen/cpu/codegen_cpu.h b/taichi/codegen/cpu/codegen_cpu.h index 71991def84566..880029149468c 100644 --- a/taichi/codegen/cpu/codegen_cpu.h +++ b/taichi/codegen/cpu/codegen_cpu.h @@ -45,7 +45,7 @@ class CPUModuleToFunctionConverter : public ModuleToFunctionConverter { FunctionType convert(const std::string &kernel_name, const std::vector &args, - std::vector &&data) const override; + LLVMCompiledData data) const override; }; #endif diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index 3572e1319f383..87a2f657b2209 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -704,7 +704,7 @@ FunctionType KernelCodeGenCUDA::compile_to_function() { auto *llvm_prog = get_llvm_program(prog); auto *tlctx = llvm_prog->get_llvm_context(kernel->arch); - std::vector data = compile_kernel_to_module(); + LLVMCompiledData data = compile_kernel_to_module(); CUDAModuleToFunctionConverter converter{tlctx, llvm_prog->get_runtime_executor()}; @@ -714,9 +714,9 @@ FunctionType KernelCodeGenCUDA::compile_to_function() { FunctionType CUDAModuleToFunctionConverter::convert( const std::string &kernel_name, const std::vector &args, - std::vector &&data) const { - auto &mod = data[0].module; - auto &tasks = data[0].tasks; + LLVMCompiledData data) const { + auto &mod = data.module; + auto &tasks = data.tasks; #ifdef TI_WITH_CUDA for (const auto &task : tasks) { llvm::Function *func = mod->getFunction(task.name); diff --git a/taichi/codegen/cuda/codegen_cuda.h b/taichi/codegen/cuda/codegen_cuda.h index 8cd80a4ee2d22..54b618b0ad997 100644 --- a/taichi/codegen/cuda/codegen_cuda.h +++ b/taichi/codegen/cuda/codegen_cuda.h @@ -39,7 +39,7 @@ class CUDAModuleToFunctionConverter : public ModuleToFunctionConverter { FunctionType convert(const std::string &kernel_name, const std::vector &args, - std::vector &&data) const override; + LLVMCompiledData data) const override; }; TLANG_NAMESPACE_END diff --git a/taichi/codegen/llvm/llvm_compiled_data.h b/taichi/codegen/llvm/llvm_compiled_data.h index 59f9a33c3bb42..a98fa7fb0e719 100644 --- a/taichi/codegen/llvm/llvm_compiled_data.h +++ b/taichi/codegen/llvm/llvm_compiled_data.h @@ -28,6 +28,7 @@ struct LLVMCompiledData { std::unordered_set struct_for_tls_sizes; LLVMCompiledData() = default; LLVMCompiledData(LLVMCompiledData &&) = default; + LLVMCompiledData &operator=(LLVMCompiledData &&) = default; LLVMCompiledData(std::vector tasks, std::unique_ptr module, std::unordered_set used_tree_ids, diff --git a/taichi/codegen/wasm/codegen_wasm.cpp b/taichi/codegen/wasm/codegen_wasm.cpp index bd3104ecb8173..e5c3934366555 100644 --- a/taichi/codegen/wasm/codegen_wasm.cpp +++ b/taichi/codegen/wasm/codegen_wasm.cpp @@ -244,7 +244,7 @@ class TaskCodeGenWASM : public TaskCodeGenLLVM { FunctionType KernelCodeGenWASM::compile_to_function() { TI_AUTO_PROF - auto linked = std::move(compile_kernel_to_module()[0]); + auto linked = compile_kernel_to_module(); auto *tlctx = get_llvm_program(prog)->get_llvm_context(kernel->arch); tlctx->create_jit_module(std::move(linked.module)); auto kernel_symbol = tlctx->lookup_function_pointer(linked.tasks[0].name); @@ -281,7 +281,7 @@ LLVMCompiledData KernelCodeGenWASM::compile_task( return {name_list, std::move(gen->module), {}, {}}; } -std::vector KernelCodeGenWASM::compile_kernel_to_module() { +LLVMCompiledData KernelCodeGenWASM::compile_kernel_to_module() { auto *tlctx = get_llvm_program(prog)->get_llvm_context(kernel->arch); if (!kernel->lowered()) { kernel->lower(/*to_executable=*/false); @@ -289,10 +289,7 @@ std::vector KernelCodeGenWASM::compile_kernel_to_module() { auto res = compile_task(); std::vector> data; data.push_back(std::make_unique(std::move(res))); - auto linked = tlctx->link_compiled_tasks(std::move(data)); - std::vector ret; - ret.push_back(std::move(*linked)); - return ret; + return tlctx->link_compiled_tasks(std::move(data)); } } // namespace lang diff --git a/taichi/codegen/wasm/codegen_wasm.h b/taichi/codegen/wasm/codegen_wasm.h index 91ec5aaf6a648..6ae99ac291b71 100644 --- a/taichi/codegen/wasm/codegen_wasm.h +++ b/taichi/codegen/wasm/codegen_wasm.h @@ -24,7 +24,7 @@ class KernelCodeGenWASM : public KernelCodeGen { std::unique_ptr &&module = nullptr, OffloadedStmt *stmt = nullptr) override; // AOT Module Gen - std::vector compile_kernel_to_module() override; + LLVMCompiledData compile_kernel_to_module() override; #endif }; diff --git a/taichi/runtime/cpu/aot_module_builder_impl.cpp b/taichi/runtime/cpu/aot_module_builder_impl.cpp index daf89298f3d76..15b48e540855a 100644 --- a/taichi/runtime/cpu/aot_module_builder_impl.cpp +++ b/taichi/runtime/cpu/aot_module_builder_impl.cpp @@ -11,7 +11,7 @@ namespace cpu { LLVMCompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) { auto cgen = KernelCodeGenCPU(kernel); - return std::move(cgen.compile_kernel_to_module()[0]); + return cgen.compile_kernel_to_module(); } } // namespace cpu diff --git a/taichi/runtime/cpu/aot_module_loader_impl.cpp b/taichi/runtime/cpu/aot_module_loader_impl.cpp index 8345d373a07ec..bfe63e4370fa8 100644 --- a/taichi/runtime/cpu/aot_module_loader_impl.cpp +++ b/taichi/runtime/cpu/aot_module_loader_impl.cpp @@ -25,7 +25,7 @@ class AotModuleImpl : public LlvmAotModule { CPUModuleToFunctionConverter converter{tlctx, executor_}; return converter.convert(name, loaded.args, - std::move(loaded.compiled_data_list)); + std::move(loaded.compiled_data)); } std::unique_ptr make_new_kernel_template( diff --git a/taichi/runtime/cuda/aot_module_builder_impl.cpp b/taichi/runtime/cuda/aot_module_builder_impl.cpp index 0d7431d9874ad..42a1ab3e03db0 100644 --- a/taichi/runtime/cuda/aot_module_builder_impl.cpp +++ b/taichi/runtime/cuda/aot_module_builder_impl.cpp @@ -11,7 +11,7 @@ namespace cuda { LLVMCompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) { auto cgen = KernelCodeGenCUDA(kernel); - return std::move(cgen.compile_kernel_to_module()[0]); + return cgen.compile_kernel_to_module(); } } // namespace cuda diff --git a/taichi/runtime/cuda/aot_module_loader_impl.cpp b/taichi/runtime/cuda/aot_module_loader_impl.cpp index 42dd4bc961fbc..1c0a9285cf762 100644 --- a/taichi/runtime/cuda/aot_module_loader_impl.cpp +++ b/taichi/runtime/cuda/aot_module_loader_impl.cpp @@ -25,7 +25,7 @@ class AotModuleImpl : public LlvmAotModule { CUDAModuleToFunctionConverter converter{tlctx, executor_}; return converter.convert(name, loaded.args, - std::move(loaded.compiled_data_list)); + std::move(loaded.compiled_data)); } std::unique_ptr make_new_kernel_template( diff --git a/taichi/runtime/llvm/llvm_aot_module_builder.cpp b/taichi/runtime/llvm/llvm_aot_module_builder.cpp index 9860a93247c7e..c8e61c4484325 100644 --- a/taichi/runtime/llvm/llvm_aot_module_builder.cpp +++ b/taichi/runtime/llvm/llvm_aot_module_builder.cpp @@ -22,7 +22,7 @@ void LlvmAotModuleBuilder::add_per_backend(const std::string &identifier, auto compiled = compile_kernel(kernel); LlvmOfflineCache::KernelCacheData kcache; kcache.kernel_key = identifier; - kcache.compiled_data_list.push_back(std::move(compiled)); + kcache.compiled_data = std::move(compiled); kcache.args = infer_launch_args(kernel); kcache.last_used_at = std::time(nullptr); kcache.created_at = std::time(nullptr); diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index 31c24b1305038..52eeb803857d6 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -889,9 +889,9 @@ TaichiLLVMContext::ThreadLocalData::~ThreadLocalData() { thread_safe_llvm_context.reset(); } -std::unique_ptr TaichiLLVMContext::link_compiled_tasks( +LLVMCompiledData TaichiLLVMContext::link_compiled_tasks( std::vector> data_list) { - auto linked = std::make_unique(); + LLVMCompiledData linked; std::unordered_set used_tree_ids; std::unordered_set tls_sizes; std::unordered_set offloaded_names; @@ -906,7 +906,7 @@ std::unique_ptr TaichiLLVMContext::link_compiled_tasks( } for (auto &task : datum->tasks) { offloaded_names.insert(task.name); - linked->tasks.push_back(std::move(task)); + linked.tasks.push_back(std::move(task)); } linker.linkInModule(clone_module_to_context( datum->module.get(), linking_context_data->llvm_context)); @@ -927,7 +927,7 @@ std::unique_ptr TaichiLLVMContext::link_compiled_tasks( eliminate_unused_functions(mod.get(), [&](std::string func_name) -> bool { return offloaded_names.count(func_name); }); - linked->module = std::move(mod); + linked.module = std::move(mod); return linked; } diff --git a/taichi/runtime/llvm/llvm_context.h b/taichi/runtime/llvm/llvm_context.h index a4dda7ce846eb..de915d0b77f63 100644 --- a/taichi/runtime/llvm/llvm_context.h +++ b/taichi/runtime/llvm/llvm_context.h @@ -142,7 +142,7 @@ class TaichiLLVMContext { static std::string get_struct_for_func_name(int tls_size); - std::unique_ptr link_compiled_tasks( + LLVMCompiledData link_compiled_tasks( std::vector> data_list); private: diff --git a/taichi/runtime/llvm/llvm_offline_cache.cpp b/taichi/runtime/llvm/llvm_offline_cache.cpp index e61a1640ede89..ed4dbab2160dc 100644 --- a/taichi/runtime/llvm/llvm_offline_cache.cpp +++ b/taichi/runtime/llvm/llvm_offline_cache.cpp @@ -95,11 +95,9 @@ struct CacheCleanerUtils { const CacheCleanerConfig &config, const KernelMetaData &kernel_meta) { std::vector result; - for (int i = 0; i < kernel_meta.compiled_data_list.size(); i++) { - for (const auto &f : get_possible_llvm_cache_filename_by_key( - kernel_meta.kernel_key + "." + std::to_string(i))) { - result.push_back(f); - } + for (const auto &f : + get_possible_llvm_cache_filename_by_key(kernel_meta.kernel_key)) { + result.push_back(f); } return result; } @@ -196,21 +194,16 @@ bool LlvmOfflineCacheFileReader::get_kernel_cache( } auto &kernel_data = itr->second; - for (int i = 0; i < kernel_data.compiled_data_list.size(); i++) { - auto &data = kernel_data.compiled_data_list[i]; + auto &data = kernel_data.compiled_data; + if (!data.module) { + std::string filename_prefix = taichi::join_path(path_, key); + data.module = load_module(filename_prefix, key, llvm_ctx); if (!data.module) { - std::string filename_prefix = - taichi::join_path(path_, key + "." + std::to_string(i)); - data.module = load_module(filename_prefix, key, llvm_ctx); - if (!data.module) { - data_.kernels.erase(itr); - return false; // Must return - } + data_.kernels.erase(itr); + return false; // Must return } - res.compiled_data_list.emplace_back( - data.tasks, llvm::CloneModule(*data.module), data.used_tree_ids, - data.struct_for_tls_sizes); } + res.compiled_data = data.clone(); kernel_data.last_used_at = std::time(nullptr); @@ -220,27 +213,21 @@ bool LlvmOfflineCacheFileReader::get_kernel_cache( res.args = kernel_data.args; // Verify the `res: LlvmOfflineCache::KernelCacheData` - bool verified_all = true; - const auto &compiled_data_list = res.compiled_data_list; - for (std::size_t i = 0; i < compiled_data_list.size(); ++i) { - const auto &data = compiled_data_list[i]; - const auto &tasks = data.tasks; - bool verified = true; - for (const auto &t : tasks) { - if (data.module->getFunction(t.name) == nullptr) { - verified = false; - verified_all = false; - } + const auto &compiled_data = res.compiled_data; + const auto &tasks = compiled_data.tasks; + bool verified = true; + for (const auto &t : tasks) { + if (compiled_data.module->getFunction(t.name) == nullptr) { + verified = false; } - if (!verified) { - for (const auto &f : get_possible_llvm_cache_filename_by_key( - key + "." + std::to_string(i))) { - taichi::remove(taichi::join_path(path_, f)); - } + } + if (!verified) { + for (const auto &f : get_possible_llvm_cache_filename_by_key(key)) { + taichi::remove(taichi::join_path(path_, f)); } } - return verified_all; + return verified; } std::unique_ptr LlvmOfflineCacheFileReader::load_module( @@ -291,33 +278,28 @@ void LlvmOfflineCacheFileWriter::dump(const std::string &path, std::size_t size = 0; // bytes std::string filename_prefix = taichi::join_path(path, k); { - mangle_offloaded_task_name(k, v.compiled_data_list); - for (int i = 0; i < v.compiled_data_list.size(); i++) { - auto &data = v.compiled_data_list[i]; - auto *mod = data.module.get(); - TI_ASSERT(mod != nullptr); - std::string suffix = "." + std::to_string(i); - if (format & Format::LL) { - std::string filename = filename_prefix + suffix + ".ll"; - if (try_lock_with_file(filename)) { // Not exists - size += - write_llvm_module(filename, [mod](llvm::raw_os_ostream &os) { - mod->print(os, /*AAW=*/nullptr); - }); - } else { - TI_DEBUG("Cache file {} exists", filename); - } + mangle_offloaded_task_name(k, v.compiled_data); + auto &data = v.compiled_data; + auto *mod = data.module.get(); + TI_ASSERT(mod != nullptr); + if (format & Format::LL) { + std::string filename = filename_prefix + ".ll"; + if (try_lock_with_file(filename)) { // Not exists + size += write_llvm_module(filename, [mod](llvm::raw_os_ostream &os) { + mod->print(os, /*AAW=*/nullptr); + }); + } else { + TI_DEBUG("Cache file {} exists", filename); } - if (format & Format::BC) { - std::string filename = filename_prefix + suffix + ".bc"; - if (try_lock_with_file(filename)) { // Not exists - size += - write_llvm_module(filename, [mod](llvm::raw_os_ostream &os) { - llvm::WriteBitcodeToFile(*mod, os); - }); - } else { - TI_DEBUG("Cache file {} exists", filename); - } + } + if (format & Format::BC) { + std::string filename = filename_prefix + ".bc"; + if (try_lock_with_file(filename)) { // Not exists + size += write_llvm_module(filename, [mod](llvm::raw_os_ostream &os) { + llvm::WriteBitcodeToFile(*mod, os); + }); + } else { + TI_DEBUG("Cache file {} exists", filename); } } } @@ -402,21 +384,19 @@ void LlvmOfflineCacheFileWriter::merge_with(LlvmOfflineCache &&data) { void LlvmOfflineCacheFileWriter::mangle_offloaded_task_name( const std::string &kernel_key, - std::vector &compiled_data_list) { + LLVMCompiledData &compiled_data) { if (!mangled_) { - for (auto &e : compiled_data_list) { - for (auto &offload : e.tasks) { - std::string mangled_name = - offline_cache::mangle_name(offload.name, kernel_key); - TI_DEBUG( - "Mangle offloaded-task from internal name '{}' to offline cache " - "key '{}'", - offload.name, mangled_name); - auto func = e.module->getFunction(offload.name); - TI_ASSERT(func != nullptr); - func->setName(mangled_name); - offload.name = mangled_name; - } + for (auto &offload : compiled_data.tasks) { + std::string mangled_name = + offline_cache::mangle_name(offload.name, kernel_key); + TI_DEBUG( + "Mangle offloaded-task from internal name '{}' to offline cache " + "key '{}'", + offload.name, mangled_name); + auto func = compiled_data.module->getFunction(offload.name); + TI_ASSERT(func != nullptr); + func->setName(mangled_name); + offload.name = mangled_name; } } } @@ -439,11 +419,7 @@ void LlvmOfflineCacheFileWriter::clean_cache(const std::string &path, LlvmOfflineCache::KernelCacheData LlvmOfflineCache::KernelCacheData::clone() const { - std::vector new_data_list; - for (const auto &data : compiled_data_list) { - new_data_list.push_back(data.clone()); - } - return {kernel_key, args, std::move(new_data_list)}; + return {kernel_key, args, compiled_data.clone()}; } } // namespace lang } // namespace taichi diff --git a/taichi/runtime/llvm/llvm_offline_cache.h b/taichi/runtime/llvm/llvm_offline_cache.h index 955a9b6aae185..0499fad6d6b43 100644 --- a/taichi/runtime/llvm/llvm_offline_cache.h +++ b/taichi/runtime/llvm/llvm_offline_cache.h @@ -25,7 +25,7 @@ struct LlvmOfflineCache { struct KernelCacheData { std::string kernel_key; std::vector args; - std::vector compiled_data_list; + LLVMCompiledData compiled_data; // For cache cleaning std::size_t size{0}; // byte @@ -39,12 +39,7 @@ struct LlvmOfflineCache { KernelCacheData clone() const; - TI_IO_DEF(kernel_key, - args, - compiled_data_list, - size, - created_at, - last_used_at); + TI_IO_DEF(kernel_key, args, compiled_data, size, created_at, last_used_at); }; struct FieldCacheData { @@ -175,9 +170,8 @@ class LlvmOfflineCacheFileWriter { private: void merge_with(LlvmOfflineCache &&data); - void mangle_offloaded_task_name( - const std::string &kernel_key, - std::vector &compiled_data_list); + void mangle_offloaded_task_name(const std::string &kernel_key, + LLVMCompiledData &compiled_data); LlvmOfflineCache data_; bool mangled_{false}; diff --git a/taichi/runtime/program_impls/llvm/llvm_program.cpp b/taichi/runtime/program_impls/llvm/llvm_program.cpp index b058e545fe5db..8bce3d2b1d9a7 100644 --- a/taichi/runtime/program_impls/llvm/llvm_program.cpp +++ b/taichi/runtime/program_impls/llvm/llvm_program.cpp @@ -123,20 +123,15 @@ std::unique_ptr LlvmProgramImpl::make_aot_kernel(Kernel &kernel) { std::move(compiled_kernel)); } -void LlvmProgramImpl::cache_kernel( - const std::string &kernel_key, - const std::vector &data_list, - std::vector &&args) { +void LlvmProgramImpl::cache_kernel(const std::string &kernel_key, + const LLVMCompiledData &data, + std::vector &&args) { if (cache_data_->kernels.find(kernel_key) != cache_data_->kernels.end()) { return; } auto &kernel_cache = cache_data_->kernels[kernel_key]; kernel_cache.kernel_key = kernel_key; - for (const auto &data : data_list) { - kernel_cache.compiled_data_list.emplace_back( - data.tasks, llvm::CloneModule(*data.module), data.used_tree_ids, - data.struct_for_tls_sizes); - } + kernel_cache.compiled_data = data.clone(); kernel_cache.args = std::move(args); kernel_cache.created_at = std::time(nullptr); kernel_cache.last_used_at = std::time(nullptr); diff --git a/taichi/runtime/program_impls/llvm/llvm_program.h b/taichi/runtime/program_impls/llvm/llvm_program.h index 2da386bc631fa..b48a22054439c 100644 --- a/taichi/runtime/program_impls/llvm/llvm_program.h +++ b/taichi/runtime/program_impls/llvm/llvm_program.h @@ -51,7 +51,7 @@ class LlvmProgramImpl : public ProgramImpl { void materialize_snode_tree(SNodeTree *tree, uint64 *result_buffer) override; void cache_kernel(const std::string &kernel_key, - const std::vector &data_list, + const LLVMCompiledData &data, std::vector &&args); ; diff --git a/taichi/runtime/wasm/aot_module_builder_impl.cpp b/taichi/runtime/wasm/aot_module_builder_impl.cpp index c0bcacfb357e9..5fb07369ee5fc 100644 --- a/taichi/runtime/wasm/aot_module_builder_impl.cpp +++ b/taichi/runtime/wasm/aot_module_builder_impl.cpp @@ -39,10 +39,10 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier, auto module_info = KernelCodeGenWASM(kernel, nullptr).compile_kernel_to_module(); if (module_) { - llvm::Linker::linkModules(*module_, std::move(module_info[0].module), + llvm::Linker::linkModules(*module_, std::move(module_info.module), llvm::Linker::OverrideFromSrc); } else { - module_ = std::move(module_info[0].module); + module_ = std::move(module_info.module); } } diff --git a/tests/cpp/codegen/refine_coordinates_test.cpp b/tests/cpp/codegen/refine_coordinates_test.cpp index 02425dce4e4d4..0ca338b2088a4 100644 --- a/tests/cpp/codegen/refine_coordinates_test.cpp +++ b/tests/cpp/codegen/refine_coordinates_test.cpp @@ -40,7 +40,7 @@ class InvokeRefineCoordinatesBuilder : public LLVMModuleBuilder { std::vector> data_list; data_list.push_back(std::make_unique(std::move(data))); auto linked_data = tlctx->link_compiled_tasks(std::move(data_list)); - auto *jit = tlctx->create_jit_module(std::move(linked_data->module)); + auto *jit = tlctx->create_jit_module(std::move(linked_data.module)); auto *fn = jit->lookup_function(kFuncName); return reinterpret_cast(fn); } diff --git a/tests/cpp/llvm/llvm_offline_cache_test.cpp b/tests/cpp/llvm/llvm_offline_cache_test.cpp index ef622db1e6414..8572ea4dd13f9 100644 --- a/tests/cpp/llvm/llvm_offline_cache_test.cpp +++ b/tests/cpp/llvm/llvm_offline_cache_test.cpp @@ -105,10 +105,8 @@ TEST_P(LlvmOfflineCacheTest, ReadWrite) { task.block_dim = kBlockDim; task.grid_dim = kGridDim; tasks.push_back(task); - LLVMCompiledData data; - data.tasks = tasks; - data.module = make_module(*llvm_ctx); - kcache.compiled_data_list.push_back(std::move(data)); + kcache.compiled_data.tasks = tasks; + kcache.compiled_data.module = make_module(*llvm_ctx); kcache.args = arg_infos; writer.add_kernel_cache(kKernelName, std::move(kcache)); writer.set_no_mangle(); @@ -122,14 +120,14 @@ TEST_P(LlvmOfflineCacheTest, ReadWrite) { const bool ok = reader->get_kernel_cache(kcache, kKernelName, *llvm_ctx); ASSERT_TRUE(ok); EXPECT_EQ(kcache.kernel_key, kKernelName); - EXPECT_EQ(kcache.compiled_data_list[0].tasks.size(), 1); - const auto &task0 = kcache.compiled_data_list[0].tasks.front(); + EXPECT_EQ(kcache.compiled_data.tasks.size(), 1); + const auto &task0 = kcache.compiled_data.tasks.front(); EXPECT_EQ(task0.name, kTaskName); - ASSERT_NE(kcache.compiled_data_list[0].module, nullptr); - kcache.compiled_data_list[0].module->dump(); - auto jit_module = tlctx_->create_jit_module( - std::move(kcache.compiled_data_list[0].module)); + ASSERT_NE(kcache.compiled_data.module, nullptr); + kcache.compiled_data.module->dump(); + auto jit_module = + tlctx_->create_jit_module(std::move(kcache.compiled_data.module)); using FuncType = int (*)(int, int); FuncType my_add = (FuncType)jit_module->lookup_function(kTaskName); const auto res = my_add(40, 2);