diff --git a/taichi/aot/module_loader.h b/taichi/aot/module_loader.h index 634f8462e92ec..354b61d607581 100644 --- a/taichi/aot/module_loader.h +++ b/taichi/aot/module_loader.h @@ -72,9 +72,6 @@ class TI_DLL_EXPORT Module { protected: virtual std::unique_ptr make_new_kernel(const std::string &name) = 0; - - private: - std::unordered_map> loaded_kernels_; }; // Only responsible for reporting device capabilities diff --git a/taichi/backends/metal/aot_module_loader_impl.cpp b/taichi/backends/metal/aot_module_loader_impl.cpp index a02192c6c7115..c5d48cc7c90b8 100644 --- a/taichi/backends/metal/aot_module_loader_impl.cpp +++ b/taichi/backends/metal/aot_module_loader_impl.cpp @@ -8,6 +8,17 @@ namespace lang { namespace metal { namespace { +class FieldImpl : public aot::Field { + public: + explicit FieldImpl(KernelManager *runtime, const CompiledFieldData &field) + : runtime_(runtime), field_(field) { + } + + private: + KernelManager *const runtime_; + CompiledFieldData field_; +}; + class KernelImpl : public aot::Kernel { public: explicit KernelImpl(KernelManager *runtime, const std::string &kernel_name) diff --git a/taichi/backends/vulkan/aot_module_loader_impl.cpp b/taichi/backends/vulkan/aot_module_loader_impl.cpp index 249d0fea570f0..44265cb684ddd 100644 --- a/taichi/backends/vulkan/aot_module_loader_impl.cpp +++ b/taichi/backends/vulkan/aot_module_loader_impl.cpp @@ -12,6 +12,17 @@ namespace { using KernelHandle = VkRuntime::KernelHandle; +class FieldImpl : public aot::Field { + public: + explicit FieldImpl(VkRuntime *runtime, const aot::CompiledFieldData &field) + : runtime_(runtime), field_(field) { + } + + private: + VkRuntime *const runtime_; + aot::CompiledFieldData field_; +}; + class KernelImpl : public aot::Kernel { public: explicit KernelImpl(VkRuntime *runtime, KernelHandle handle) @@ -53,7 +64,12 @@ class AotModuleImpl : public aot::Module { } std::unique_ptr get_field(const std::string &name) override { - TI_NOT_IMPLEMENTED; + aot::CompiledFieldData field; + if (!get_field_data_by_name(name, field)) { + TI_DEBUG("Failed to load field {}", name); + return nullptr; + } + return std::make_unique(runtime_, field); } size_t get_root_size() const override { @@ -69,6 +85,17 @@ class AotModuleImpl : public aot::Module { } private: + bool get_field_data_by_name(const std::string &name, + aot::CompiledFieldData &field) { + for (int i = 0; i < ti_aot_data_.fields.size(); ++i) { + if (ti_aot_data_.fields[i].field_name.rfind(name, 0) == 0) { + field = ti_aot_data_.fields[i]; + return true; + } + } + return false; + } + bool get_kernel_params_by_name(const std::string &name, VkRuntime::RegisterParams &kernel) { for (int i = 0; i < ti_aot_data_.kernels.size(); ++i) { diff --git a/taichi/backends/vulkan/aot_module_loader_impl.h b/taichi/backends/vulkan/aot_module_loader_impl.h index 37e28d01f6388..5329e889420c2 100644 --- a/taichi/backends/vulkan/aot_module_loader_impl.h +++ b/taichi/backends/vulkan/aot_module_loader_impl.h @@ -21,6 +21,7 @@ struct AotModuleParams { }; std::unique_ptr make_aot_module(std::any mod_params); + } // namespace vulkan } // namespace lang } // namespace taichi diff --git a/taichi/backends/vulkan/runtime.cpp b/taichi/backends/vulkan/runtime.cpp index 96325f2c848b6..fe2a1bf4ab04f 100644 --- a/taichi/backends/vulkan/runtime.cpp +++ b/taichi/backends/vulkan/runtime.cpp @@ -591,15 +591,32 @@ void VkRuntime::add_root_buffer(size_t root_buffer_size) { root_buffer_size = 4; // there might be empty roots } std::unique_ptr new_buffer = - device_->allocate_memory_unique( - {root_buffer_size, - /*host_write=*/false, /*host_read=*/false, - /*export_sharing=*/false, AllocUsage::Storage}); + device_->allocate_memory_unique({root_buffer_size, + /*host_write=*/true, /*host_read=*/true, + /*export_sharing=*/false, + AllocUsage::Storage}); Stream *stream = device_->get_compute_stream(); auto cmdlist = stream->new_command_list(); cmdlist->buffer_fill(new_buffer->get_ptr(0), root_buffer_size, /*data=*/0); stream->submit_synced(cmdlist.get()); root_buffers_.push_back(std::move(new_buffer)); + // cache the root buffer size + root_buffers_size_map_[root_buffers_.back().get()] = root_buffer_size; +} + +DeviceAllocation *VkRuntime::get_root_buffer(int id) const { + if (id >= root_buffers_.size()) { + TI_ERROR("root buffer id {} not found", id); + } + return root_buffers_[id].get(); +} + +size_t VkRuntime::get_root_buffer_size(int id) const { + auto it = root_buffers_size_map_.find(root_buffers_[id].get()); + if (id >= root_buffers_.size() || it == root_buffers_size_map_.end()) { + TI_ERROR("root buffer id {} not found", id); + } + return it->second; } VkRuntime::RegisterParams run_codegen( diff --git a/taichi/backends/vulkan/runtime.h b/taichi/backends/vulkan/runtime.h index d862dcf5cf755..fecf9812ccb4f 100644 --- a/taichi/backends/vulkan/runtime.h +++ b/taichi/backends/vulkan/runtime.h @@ -106,6 +106,10 @@ class TI_DLL_EXPORT VkRuntime { void add_root_buffer(size_t root_buffer_size); + DeviceAllocation *get_root_buffer(int id) const; + + size_t get_root_buffer_size(int id) const; + private: friend class taichi::lang::vulkan::SNodeTreeManager; @@ -125,6 +129,8 @@ class TI_DLL_EXPORT VkRuntime { high_res_clock::time_point current_cmdlist_pending_since_; std::vector> ti_kernels_; + + std::unordered_map root_buffers_size_map_; }; VkRuntime::RegisterParams run_codegen( diff --git a/tests/cpp/aot/aot_save_load_test.cpp b/tests/cpp/aot/aot_save_load_test.cpp index 45885d7e28bf8..9810b79d75b5c 100644 --- a/tests/cpp/aot/aot_save_load_test.cpp +++ b/tests/cpp/aot/aot_save_load_test.cpp @@ -127,7 +127,7 @@ TEST(AotSaveLoad, Vulkan) { aot::Module::load(".", Arch::vulkan, mod_params); EXPECT_TRUE(vk_module); - // Retrieve kernels/fields/etc from AOT module to initialize runtime + // Retrieve kernels/fields/etc from AOT module auto root_size = vk_module->get_root_size(); EXPECT_EQ(root_size, 64); vulkan_runtime->add_root_buffer(root_size); @@ -146,8 +146,8 @@ TEST(AotSaveLoad, Vulkan) { ret_kernel->launch(&host_ctx); vulkan_runtime->synchronize(); - // auto x_field = vk_module.get_field("x"); - // EXPECT_TRUE(x_field); - // x_field.copy_to(/*dst=*/x.get()); + // Retrieve data + auto x_field = vk_module->get_field("place"); + EXPECT_NE(x_field, nullptr); } #endif