From c4174e0059e8a943809d8ce3a3c679147523d354 Mon Sep 17 00:00:00 2001 From: PGZXB Date: Mon, 26 Sep 2022 18:14:44 +0800 Subject: [PATCH] [metal] Maintain a print string table per kernel (#6160) Issue: #4401 * Fix a potential bug in metal AOT * Prepare for implementing offline cache on metal Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- taichi/codegen/metal/codegen_metal.cpp | 21 ++--- taichi/codegen/metal/codegen_metal.h | 1 - .../runtime/metal/aot_module_builder_impl.cpp | 4 +- .../runtime/metal/aot_module_builder_impl.h | 1 - .../runtime/metal/aot_module_loader_impl.cpp | 4 +- taichi/runtime/metal/kernel_manager.cpp | 81 +++++++------------ taichi/runtime/metal/kernel_manager.h | 6 +- taichi/runtime/metal/kernel_utils.cpp | 2 +- taichi/runtime/metal/kernel_utils.h | 7 +- 9 files changed, 47 insertions(+), 80 deletions(-) diff --git a/taichi/codegen/metal/codegen_metal.cpp b/taichi/codegen/metal/codegen_metal.cpp index b14cf49780aaa..8879c5efc30d8 100644 --- a/taichi/codegen/metal/codegen_metal.cpp +++ b/taichi/codegen/metal/codegen_metal.cpp @@ -187,14 +187,12 @@ class KernelCodegenImpl : public IRVisitor { Kernel *kernel, const CompiledRuntimeModule *compiled_runtime_module, const std::vector &compiled_snode_trees, - PrintStringTable *print_strtab, const Config &config, OffloadedStmt *offloaded) : mtl_kernel_prefix_(taichi_kernel_name), kernel_(kernel), compiled_runtime_module_(compiled_runtime_module), compiled_snode_trees_(compiled_snode_trees), - print_strtab_(print_strtab), cgen_config_(config), offloaded_(offloaded), ctx_attribs_(*kernel_) { @@ -216,11 +214,13 @@ class KernelCodegenImpl : public IRVisitor { } CompiledKernelData run() { + CompiledKernelData res; + print_strtab_ = &res.print_str_table; + emit_headers(); generate_structs(); generate_kernels(); - CompiledKernelData res; res.kernel_name = mtl_kernel_prefix_; res.kernel_attribs = std::move(ti_kernel_attribs_); res.ctx_attribs = std::move(ctx_attribs_); @@ -1654,7 +1654,7 @@ class KernelCodegenImpl : public IRVisitor { }; std::unordered_map snode_to_roots_; std::unordered_map root_id_to_stmts_; - PrintStringTable *const print_strtab_; + PrintStringTable *print_strtab_{nullptr}; const Config &cgen_config_; OffloadedStmt *const offloaded_; @@ -1675,7 +1675,6 @@ CompiledKernelData run_codegen( const CompiledRuntimeModule *compiled_runtime_module, const std::vector &compiled_snode_trees, Kernel *kernel, - PrintStringTable *strtab, OffloadedStmt *offloaded) { const auto id = Program::get_kernel_id(); const auto taichi_kernel_name( @@ -1685,8 +1684,7 @@ CompiledKernelData run_codegen( cgen_config.allow_simdgroup = EnvConfig::instance().is_simdgroup_enabled(); KernelCodegenImpl codegen(taichi_kernel_name, kernel, compiled_runtime_module, - compiled_snode_trees, strtab, cgen_config, - offloaded); + compiled_snode_trees, cgen_config, offloaded); return codegen.run(); } @@ -1697,12 +1695,9 @@ FunctionType compile_to_metal_executable( const CompiledRuntimeModule *compiled_runtime_module, const std::vector &compiled_snode_trees, OffloadedStmt *offloaded) { - const auto compiled_res = - run_codegen(compiled_runtime_module, compiled_snode_trees, kernel, - kernel_mgr->print_strtable(), offloaded); - kernel_mgr->register_taichi_kernel( - compiled_res.kernel_name, compiled_res.source_code, - compiled_res.kernel_attribs, compiled_res.ctx_attribs, kernel); + const auto compiled_res = run_codegen( + compiled_runtime_module, compiled_snode_trees, kernel, offloaded); + kernel_mgr->register_taichi_kernel(compiled_res); return [kernel_mgr, kernel_name = compiled_res.kernel_name](RuntimeContext &ctx) { kernel_mgr->launch_taichi_kernel(kernel_name, &ctx); diff --git a/taichi/codegen/metal/codegen_metal.h b/taichi/codegen/metal/codegen_metal.h index 30d5a26b913aa..12003e22e44b4 100644 --- a/taichi/codegen/metal/codegen_metal.h +++ b/taichi/codegen/metal/codegen_metal.h @@ -20,7 +20,6 @@ CompiledKernelData run_codegen( const CompiledRuntimeModule *compiled_runtime_module, const std::vector &compiled_snode_trees, Kernel *kernel, - PrintStringTable *print_strtab, OffloadedStmt *offloaded); // If |offloaded| is nullptr, this compiles the AST in |kernel|. Otherwise it diff --git a/taichi/runtime/metal/aot_module_builder_impl.cpp b/taichi/runtime/metal/aot_module_builder_impl.cpp index ffdcba25a3b44..532cd060e4b03 100644 --- a/taichi/runtime/metal/aot_module_builder_impl.cpp +++ b/taichi/runtime/metal/aot_module_builder_impl.cpp @@ -55,7 +55,7 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir, void AotModuleBuilderImpl::add_per_backend(const std::string &identifier, Kernel *kernel) { auto compiled = run_codegen(compiled_runtime_module_, compiled_snode_trees_, - kernel, &strtab_, /*offloaded=*/nullptr); + kernel, /*offloaded=*/nullptr); compiled.kernel_name = identifier; ti_aot_data_.kernels.push_back(std::move(compiled)); } @@ -89,7 +89,7 @@ void AotModuleBuilderImpl::add_per_backend_tmpl(const std::string &identifier, const std::string &key, Kernel *kernel) { auto compiled = run_codegen(compiled_runtime_module_, compiled_snode_trees_, - kernel, &strtab_, /*offloaded=*/nullptr); + kernel, /*offloaded=*/nullptr); for (auto &k : ti_aot_data_.tmpl_kernels) { if (k.kernel_bundle_name == identifier) { k.kernel_tmpl_map.insert(std::make_pair(key, compiled)); diff --git a/taichi/runtime/metal/aot_module_builder_impl.h b/taichi/runtime/metal/aot_module_builder_impl.h index 04846a0d3f211..7ce10bccf4e89 100644 --- a/taichi/runtime/metal/aot_module_builder_impl.h +++ b/taichi/runtime/metal/aot_module_builder_impl.h @@ -44,7 +44,6 @@ class AotModuleBuilderImpl : public AotModuleBuilder { const CompiledRuntimeModule *compiled_runtime_module_; const std::vector &compiled_snode_trees_; const std::unordered_set fields_; - PrintStringTable strtab_; TaichiAotData ti_aot_data_; }; diff --git a/taichi/runtime/metal/aot_module_loader_impl.cpp b/taichi/runtime/metal/aot_module_loader_impl.cpp index 4008f62399c52..95cc44850a84a 100644 --- a/taichi/runtime/metal/aot_module_loader_impl.cpp +++ b/taichi/runtime/metal/aot_module_loader_impl.cpp @@ -70,9 +70,7 @@ class AotModuleImpl : public aot::Module { return nullptr; } auto *kernel_data = itr->second; - runtime_->register_taichi_kernel( - name, kernel_data->source_code, kernel_data->kernel_attribs, - kernel_data->ctx_attribs, /*kernel=*/nullptr); + runtime_->register_taichi_kernel(*kernel_data); return std::make_unique(runtime_, name); } diff --git a/taichi/runtime/metal/kernel_manager.cpp b/taichi/runtime/metal/kernel_manager.cpp index 0ed8977afc806..46a955e9e1905 100644 --- a/taichi/runtime/metal/kernel_manager.cpp +++ b/taichi/runtime/metal/kernel_manager.cpp @@ -286,20 +286,20 @@ class CompiledTaichiKernel { public: struct Params { std::string mtl_source_code; - const TaichiKernelAttributes *ti_kernel_attribs; - const KernelContextAttributes *ctx_attribs; - MTLDevice *device; - MemoryPool *mem_pool; - KernelProfilerBase *profiler; - const CompileConfig *compile_config; - const Kernel *kernel; - Device *rhi_device; + const TaichiKernelAttributes *ti_kernel_attribs{nullptr}; + const KernelContextAttributes *ctx_attribs{nullptr}; + const PrintStringTable *print_str_table{nullptr}; + MTLDevice *device{nullptr}; + MemoryPool *mem_pool{nullptr}; + KernelProfilerBase *profiler{nullptr}; + const CompileConfig *compile_config{nullptr}; + Device *rhi_device{nullptr}; }; CompiledTaichiKernel(Params params) : ti_kernel_attribs(*params.ti_kernel_attribs), ctx_attribs(*params.ctx_attribs), - kernel_(params.kernel), + print_str_table(*params.print_str_table), rhi_device_(params.rhi_device) { auto *const device = params.device; auto kernel_lib = new_library_with_source( @@ -420,6 +420,7 @@ class CompiledTaichiKernel { std::vector> compiled_mtl_kernels; TaichiKernelAttributes ti_kernel_attribs; KernelContextAttributes ctx_attribs; + PrintStringTable print_str_table; std::unique_ptr ctx_mem; nsobj_unique_ptr ctx_buffer; @@ -430,7 +431,6 @@ class CompiledTaichiKernel { std::unordered_map ext_arr_arg_to_dev_alloc; private: - const Kernel *const kernel_; Device *const rhi_device_; }; @@ -726,34 +726,30 @@ class KernelManager::Impl { root_buffers_.push_back(std::move(rtbuf)); } - void register_taichi_kernel(const std::string &taichi_kernel_name, - const std::string &mtl_kernel_source_code, - const TaichiKernelAttributes &ti_kernel_attribs, - const KernelContextAttributes &ctx_attribs, - const Kernel *kernel) { - TI_ASSERT(compiled_taichi_kernels_.find(taichi_kernel_name) == + void register_taichi_kernel(const CompiledKernelData &compiled_kernel) { + TI_ASSERT(compiled_taichi_kernels_.find(compiled_kernel.kernel_name) == compiled_taichi_kernels_.end()); if (config_->print_kernel_llvm_ir) { // If users have enabled |print_kernel_llvm_ir|, it probably means that // they want to see the compiled code on the given arch. Maybe rename this // flag, or add another flag (e.g. |print_kernel_source_code|)? - TI_INFO("Metal source code for kernel <{}>\n{}", taichi_kernel_name, - mtl_kernel_source_code); + TI_INFO("Metal source code for kernel <{}>\n{}", + compiled_kernel.kernel_name, compiled_kernel.source_code); } CompiledTaichiKernel::Params params; - params.mtl_source_code = mtl_kernel_source_code; - params.ti_kernel_attribs = &ti_kernel_attribs; - params.ctx_attribs = &ctx_attribs; + params.mtl_source_code = compiled_kernel.source_code; + params.ti_kernel_attribs = &compiled_kernel.kernel_attribs; + params.ctx_attribs = &compiled_kernel.ctx_attribs; + params.print_str_table = &compiled_kernel.print_str_table; params.device = device_.get(); params.mem_pool = mem_pool_; params.profiler = profiler_; params.compile_config = config_; - params.kernel = kernel; params.rhi_device = rhi_device_.get(); - compiled_taichi_kernels_[taichi_kernel_name] = + compiled_taichi_kernels_[compiled_kernel.kernel_name] = std::make_unique(params); - TI_DEBUG("Registered Taichi kernel <{}>", taichi_kernel_name); + TI_DEBUG("Registered Taichi kernel <{}>", compiled_kernel.kernel_name); } void launch_taichi_kernel(const std::string &taichi_kernel_name, @@ -812,10 +808,10 @@ class KernelManager::Impl { ctx_blitter->metal_to_host(); } if (used.assertion) { - check_assertion_failure(); + check_assertion_failure(cti_kernel.print_str_table); } if (used.print) { - flush_print_buffers(); + flush_print_buffers(cti_kernel.print_str_table); } } } @@ -829,10 +825,6 @@ class KernelManager::Impl { return buffer_meta_data_; } - PrintStringTable *print_strtable() { - return &print_strtable_; - } - std::size_t get_snode_num_dynamically_allocated(SNode *snode) { // TODO(k-ye): Have a generic way for querying these sparse runtime stats. mac::ScopedAutoreleasePool pool; @@ -1100,7 +1092,7 @@ class KernelManager::Impl { // print_runtime_debug(); } - void check_assertion_failure() { + void check_assertion_failure(const PrintStringTable &print_str_table) { // TODO: Copy this to program's result_buffer, and let the Taichi runtime // handle the assertion failures uniformly. auto *asst_rec = reinterpret_cast( @@ -1112,7 +1104,7 @@ class KernelManager::Impl { shaders::PrintMsg msg(msg_ptr, asst_rec->num_args); using MsgType = shaders::PrintMsg::Type; TI_ASSERT(msg.pm_get_type(0) == MsgType::Str); - const auto fmt_str = print_strtable_.get(msg.pm_get_data(0)); + const auto fmt_str = print_str_table.get(msg.pm_get_data(0)); const auto err_str = format_error_message(fmt_str, [&msg](int argument_id) { // +1 to skip the first arg, which is the error message template. const int32 x = msg.pm_get_data(argument_id + 1); @@ -1141,7 +1133,7 @@ class KernelManager::Impl { throw TaichiAssertionError(err_str); } - void flush_print_buffers() { + void flush_print_buffers(const PrintStringTable &print_str_table) { auto *pa = reinterpret_cast( print_assert_idevalloc_.mem->ptr() + shaders::kMetalAssertBufferSize); const int used_sz = @@ -1166,7 +1158,7 @@ class KernelManager::Impl { } else if (dt == MsgType::F32) { py_cout << *reinterpret_cast(&x); } else if (dt == MsgType::Str) { - py_cout << print_strtable_.get(x); + py_cout << print_str_table.get(x); } else { TI_ERROR("Unexpected data type={}", dt); } @@ -1274,7 +1266,6 @@ class KernelManager::Impl { int last_snode_id_used_in_runtime_{-1}; std::unordered_map> compiled_taichi_kernels_; - PrintStringTable print_strtable_; // The |dev_*_mirror_|s are the data structures stored in the Metal device // side that get mirrored to the host side. This is possible because the @@ -1301,11 +1292,7 @@ class KernelManager::Impl { TI_ERROR("Metal not supported on the current OS"); } - void register_taichi_kernel(const std::string &taichi_kernel_name, - const std::string &mtl_kernel_source_code, - const TaichiKernelAttributes &ti_kernel_attribs, - const KernelContextAttributes &ctx_attribs, - const Kernel *kernel) { + void register_taichi_kernel(const CompiledKernelData &) { TI_ERROR("Metal not supported on the current OS"); } @@ -1351,14 +1338,8 @@ void KernelManager::add_compiled_snode_tree(const CompiledStructs &snode_tree) { impl_->add_compiled_snode_tree(snode_tree); } -void KernelManager::register_taichi_kernel( - const std::string &taichi_kernel_name, - const std::string &mtl_kernel_source_code, - const TaichiKernelAttributes &ti_kernel_attribs, - const KernelContextAttributes &ctx_attribs, - const Kernel *kernel) { - impl_->register_taichi_kernel(taichi_kernel_name, mtl_kernel_source_code, - ti_kernel_attribs, ctx_attribs, kernel); +void KernelManager::register_taichi_kernel(const CompiledKernelData &compiled) { + impl_->register_taichi_kernel(compiled); } void KernelManager::launch_taichi_kernel(const std::string &taichi_kernel_name, @@ -1374,10 +1355,6 @@ BufferMetaData KernelManager::get_buffer_meta_data() { return impl_->get_buffer_meta_data(); } -PrintStringTable *KernelManager::print_strtable() { - return impl_->print_strtable(); -} - std::size_t KernelManager::get_snode_num_dynamically_allocated(SNode *snode) { return impl_->get_snode_num_dynamically_allocated(snode); } diff --git a/taichi/runtime/metal/kernel_manager.h b/taichi/runtime/metal/kernel_manager.h index cde2c34befd87..d073c20cd23ba 100644 --- a/taichi/runtime/metal/kernel_manager.h +++ b/taichi/runtime/metal/kernel_manager.h @@ -48,11 +48,7 @@ class KernelManager { // TODO(k-ye): Remove |taichi_kernel_name| now that it's part of // |ti_kernel_attribs|. Return a handle that will be passed to // launch_taichi_kernel(), instead of using kernel name as the identifier. - void register_taichi_kernel(const std::string &taichi_kernel_name, - const std::string &mtl_kernel_source_code, - const TaichiKernelAttributes &ti_kernel_attribs, - const KernelContextAttributes &ctx_attribs, - const Kernel *kernel); + void register_taichi_kernel(const CompiledKernelData &compiled_kernel); // Launch the given |taichi_kernel_name|. // Kernel launching is asynchronous, therefore the Metal memory is not valid diff --git a/taichi/runtime/metal/kernel_utils.cpp b/taichi/runtime/metal/kernel_utils.cpp index 6ee2ffc7c520c..dea1fc0dbe806 100644 --- a/taichi/runtime/metal/kernel_utils.cpp +++ b/taichi/runtime/metal/kernel_utils.cpp @@ -22,7 +22,7 @@ int PrintStringTable::put(const std::string &str) { return i; } -const std::string &PrintStringTable::get(int i) { +const std::string &PrintStringTable::get(int i) const { return strs_[i]; } diff --git a/taichi/runtime/metal/kernel_utils.h b/taichi/runtime/metal/kernel_utils.h index edb7f9c41ada4..ef411f36aff2a 100644 --- a/taichi/runtime/metal/kernel_utils.h +++ b/taichi/runtime/metal/kernel_utils.h @@ -27,7 +27,9 @@ namespace metal { class PrintStringTable { public: int put(const std::string &str); - const std::string &get(int i); + const std::string &get(int i) const; + + TI_IO_DEF(strs_); private: std::vector strs_; @@ -279,8 +281,9 @@ struct CompiledKernelData { std::string source_code; KernelContextAttributes ctx_attribs; TaichiKernelAttributes kernel_attribs; + PrintStringTable print_str_table; - TI_IO_DEF(kernel_name, ctx_attribs, kernel_attribs); + TI_IO_DEF(kernel_name, ctx_attribs, kernel_attribs, print_str_table); }; struct CompiledKernelTmplData {