diff --git a/c_api/include/taichi/cpp/taichi.hpp b/c_api/include/taichi/cpp/taichi.hpp index 7d113adcf76d41..ea98278d0c92f5 100644 --- a/c_api/include/taichi/cpp/taichi.hpp +++ b/c_api/include/taichi/cpp/taichi.hpp @@ -797,6 +797,14 @@ class Runtime { return load_aot_module(path.c_str()); } + AotModule create_aot_module(const void *tcm, size_t size) { + TiAotModule aot_module = ti_create_aot_module(runtime_, tcm, size); + return AotModule(runtime_, aot_module, true); + } + AotModule create_aot_module(const std::vector &tcm) { + return create_aot_module(tcm.data(), tcm.size()); + } + void copy_memory_device_to_device(const TiMemorySlice &dst_memory, const TiMemorySlice &src_memory) { ti_copy_memory_device_to_device(runtime_, &dst_memory, &src_memory); diff --git a/c_api/include/taichi/taichi_core.h b/c_api/include/taichi/taichi_core.h index 0187745af96817..3242fec7e60787 100644 --- a/c_api/include/taichi/taichi_core.h +++ b/c_api/include/taichi/taichi_core.h @@ -995,6 +995,11 @@ TI_DLL_EXPORT void TI_API_CALL ti_wait(TiRuntime runtime); TI_DLL_EXPORT TiAotModule TI_API_CALL ti_load_aot_module(TiRuntime runtime, const char *module_path); +// Function `ti_create_aot_module` +TI_DLL_EXPORT TiAotModule TI_API_CALL ti_create_aot_module(TiRuntime runtime, + const void *tcm, + uint64_t size); + // Function `ti_destroy_aot_module` // // Destroys a loaded AOT module and releases all related resources. diff --git a/c_api/src/taichi_core_impl.cpp b/c_api/src/taichi_core_impl.cpp index 5df3d3b08d6d2e..8616b5b25703dc 100644 --- a/c_api/src/taichi_core_impl.cpp +++ b/c_api/src/taichi_core_impl.cpp @@ -4,6 +4,7 @@ #include "taichi_llvm_impl.h" #include "taichi/program/ndarray.h" #include "taichi/program/texture.h" +#include "taichi/common/virtual_dir.h" struct ErrorCache { TiError error{TI_ERROR_SUCCESS}; @@ -21,7 +22,7 @@ const char *describe_error(TiError error) { case TI_ERROR_NOT_SUPPORTED: return "not supported"; case TI_ERROR_CORRUPTED_DATA: - return "path not found"; + return "corrupted data"; case TI_ERROR_NAME_NOT_FOUND: return "name not found"; case TI_ERROR_INVALID_ARGUMENT: @@ -492,6 +493,8 @@ TiAotModule ti_load_aot_module(TiRuntime runtime, const char *module_path) { TI_CAPI_ARGUMENT_NULL_RV(runtime); TI_CAPI_ARGUMENT_NULL_RV(module_path); + // (penguinliong) Should call `create_aot_module` directly after all backends + // adapted to it. TiAotModule aot_module = ((Runtime *)runtime)->load_aot_module(module_path); if (aot_module == TI_NULL_HANDLE) { @@ -502,6 +505,26 @@ TiAotModule ti_load_aot_module(TiRuntime runtime, const char *module_path) { TI_CAPI_TRY_CATCH_END(); return out; } +TiAotModule ti_create_aot_module(TiRuntime runtime, + const void *tcm, + uint64_t size) { + TiAotModule out = TI_NULL_HANDLE; + TI_CAPI_TRY_CATCH_BEGIN(); + TI_CAPI_ARGUMENT_NULL_RV(runtime); + TI_CAPI_ARGUMENT_NULL_RV(tcm); + + auto dir = taichi::io::VirtualDir::from_zip(tcm, size); + if (dir == TI_NULL_HANDLE) { + ti_set_last_error(TI_ERROR_CORRUPTED_DATA, "tcm"); + return TI_NULL_HANDLE; + } + + Error err = ((Runtime *)runtime)->create_aot_module(dir.get(), out); + err.set_last_error(); + + TI_CAPI_TRY_CATCH_END(); + return out; +} void ti_destroy_aot_module(TiAotModule aot_module) { TI_CAPI_TRY_CATCH_BEGIN(); TI_CAPI_ARGUMENT_NULL(aot_module); diff --git a/c_api/src/taichi_core_impl.h b/c_api/src/taichi_core_impl.h index 70c4ce7704fd24..b7fcf57b943536 100644 --- a/c_api/src/taichi_core_impl.h +++ b/c_api/src/taichi_core_impl.h @@ -31,6 +31,7 @@ #include "taichi/rhi/device.h" #include "taichi/aot/graph_data.h" #include "taichi/aot/module_loader.h" +#include "taichi/common/virtual_dir.h" #define TI_CAPI_NOT_SUPPORTED(x) ti_set_last_error(TI_ERROR_NOT_SUPPORTED, #x); #define TI_CAPI_NOT_SUPPORTED_IF(x) \ @@ -89,6 +90,28 @@ ti_set_last_error(TI_ERROR_INVALID_STATE, "c++ exception"); \ } +struct Error { + TiError error; + std::string message; + + Error(TiError error, const std::string &message) + : error(error), message(message) { + } + Error() : error(TI_ERROR_SUCCESS), message() { + } + Error(const Error &) = delete; + Error(Error &&) = default; + Error &operator=(const Error &) = delete; + Error &operator=(Error &&) = default; + + // Set this error as the last error if it's not `TI_ERROR_SUCCESS`. + inline void set_last_error() const { + if (error != TI_ERROR_SUCCESS) { + ti_set_last_error(error, message.c_str()); + } + } +}; + class Runtime { protected: // 32 is a magic number in `taichi/inc/constants.h`. @@ -104,7 +127,18 @@ class Runtime { virtual taichi::lang::Device &get() = 0; - virtual TiAotModule load_aot_module(const char *module_path) = 0; + [[deprecated("create_aot_module")]] virtual TiAotModule load_aot_module( + const char *module_path) { + auto dir = taichi::io::VirtualDir::open(module_path); + TiAotModule aot_module = TI_NULL_HANDLE; + Error err = create_aot_module(dir.get(), aot_module); + err.set_last_error(); + return aot_module; + } + virtual Error create_aot_module(const taichi::io::VirtualDir *dir, + TiAotModule &out) { + TI_NOT_IMPLEMENTED + } virtual TiMemory allocate_memory( const taichi::lang::Device::AllocParams ¶ms); virtual void free_memory(TiMemory devmem); diff --git a/c_api/src/taichi_gfx_impl.cpp b/c_api/src/taichi_gfx_impl.cpp index e5dddb5a0a5dfb..8e7ed41f7e0e8e 100644 --- a/c_api/src/taichi_gfx_impl.cpp +++ b/c_api/src/taichi_gfx_impl.cpp @@ -4,14 +4,15 @@ GfxRuntime::GfxRuntime(taichi::Arch arch) : Runtime(arch) { } -TiAotModule GfxRuntime::load_aot_module(const char *module_path) { +Error GfxRuntime::create_aot_module(const taichi::io::VirtualDir *dir, + TiAotModule &out) { taichi::lang::gfx::AotModuleParams params{}; - params.module_path = module_path; + params.dir = dir; params.runtime = &get_gfx_runtime(); std::unique_ptr aot_module = taichi::lang::aot::Module::load(arch, params); if (aot_module->is_corrupted()) { - return TI_NULL_HANDLE; + return Error(TI_ERROR_CORRUPTED_DATA, "aot_module"); } const taichi::lang::DeviceCapabilityConfig ¤t_devcaps = @@ -21,15 +22,16 @@ TiAotModule GfxRuntime::load_aot_module(const char *module_path) { for (const auto &pair : required_devcaps.devcaps) { uint32_t current_version = current_devcaps.get(pair.first); uint32_t required_version = pair.second; - if (current_version == required_version) { - ti_set_last_error(TI_ERROR_INCOMPATIBLE_MODULE, - taichi::lang::to_string(pair.first).c_str()); + if (current_version != required_version) { + return Error(TI_ERROR_INCOMPATIBLE_MODULE, + taichi::lang::to_string(pair.first).c_str()); } } size_t root_size = aot_module->get_root_size(); params.runtime->add_root_buffer(root_size); - return (TiAotModule)(new AotModule(*this, std::move(aot_module))); + out = (TiAotModule) new AotModule(*this, std::move(aot_module)); + return Error(); } void GfxRuntime::buffer_copy(const taichi::lang::DevicePtr &dst, const taichi::lang::DevicePtr &src, diff --git a/c_api/src/taichi_gfx_impl.h b/c_api/src/taichi_gfx_impl.h index c95c2c3ba30eac..86aeb4a3cb78ca 100644 --- a/c_api/src/taichi_gfx_impl.h +++ b/c_api/src/taichi_gfx_impl.h @@ -2,6 +2,7 @@ #include "taichi_core_impl.h" #include "taichi/runtime/gfx/runtime.h" +#include "taichi/common/virtual_dir.h" class GfxRuntime; @@ -10,7 +11,8 @@ class GfxRuntime : public Runtime { GfxRuntime(taichi::Arch arch); virtual taichi::lang::gfx::GfxRuntime &get_gfx_runtime() = 0; - virtual TiAotModule load_aot_module(const char *module_path) override final; + virtual Error create_aot_module(const taichi::io::VirtualDir *dir, + TiAotModule &out) override final; virtual void buffer_copy(const taichi::lang::DevicePtr &dst, const taichi::lang::DevicePtr &src, size_t size) override final; diff --git a/c_api/taichi.json b/c_api/taichi.json index 99d79468514026..68e59ae0a880a2 100644 --- a/c_api/taichi.json +++ b/c_api/taichi.json @@ -866,6 +866,27 @@ } ] }, + { + "name": "create_aot_module", + "type": "function", + "parameters": [ + { + "name": "@return", + "type": "handle.aot_module" + }, + { + "type": "handle.runtime" + }, + { + "name": "tcm", + "type": "const void*" + }, + { + "name": "size", + "type": "uint64_t" + } + ] + }, { "name": "destroy_aot_module", "type": "function", diff --git a/c_api/tests/c_api_interface_test.cpp b/c_api/tests/c_api_interface_test.cpp index e95073e3202f34..477d8ca5ea9440 100644 --- a/c_api/tests/c_api_interface_test.cpp +++ b/c_api/tests/c_api_interface_test.cpp @@ -165,3 +165,40 @@ TEST_F(CapiTest, TestLoadTcmAotModule) { } } } + +TEST_F(CapiTest, TestCreateTcmAotModule) { + if (capi::utils::is_vulkan_available()) { + const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH"); + + std::stringstream aot_mod_ss; + aot_mod_ss << folder_dir << "/module.tcm"; + + std::vector tcm; + { + std::fstream f(aot_mod_ss.str(), + std::ios::in | std::ios::binary | std::ios::ate); + TI_ASSERT(f.is_open()); + tcm.resize(f.tellg()); + f.seekg(std::ios::beg); + f.read((char *)tcm.data(), tcm.size()); + } + + { + // Vulkan Runtime + TiArch arch = TiArch::TI_ARCH_VULKAN; + ti::Runtime runtime(arch); + ti::AotModule aot_mod = runtime.create_aot_module(tcm); + ti::Kernel run = aot_mod.get_kernel("run"); + ti::NdArray arr = + runtime.allocate_ndarray({16}, {}, true); + run[0] = arr; + run.launch(); + runtime.wait(); + std::vector data(16); + arr.read(data); + for (int32_t i = 0; i < 16; ++i) { + TI_ASSERT(data.at(i) == i); + } + } + } +} diff --git a/taichi/common/virtual_dir.cpp b/taichi/common/virtual_dir.cpp index 4b1cd7f0941728..d41baa07a5739e 100644 --- a/taichi/common/virtual_dir.cpp +++ b/taichi/common/virtual_dir.cpp @@ -60,9 +60,11 @@ struct ZipArchiveVirtualDir : public VirtualDir { f.seekg(std::ios::beg); f.read((char *)archive_data.data(), archive_data.size()); + return from_zip(archive_data.data(), archive_data.size()); + } + static std::unique_ptr from_zip(const void *data, size_t size) { zip::ZipArchive archive; - bool succ = zip::ZipArchive::try_from_bytes(archive_data.data(), - archive_data.size(), archive); + bool succ = zip::ZipArchive::try_from_bytes(data, size, archive); if (!succ) { return nullptr; } @@ -121,5 +123,14 @@ std::unique_ptr VirtualDir::open(const std::string &path) { } } +std::unique_ptr VirtualDir::from_zip(const void *data, + size_t size) { + return ZipArchiveVirtualDir::from_zip(data, size); +} +std::unique_ptr VirtualDir::from_fs_dir( + const std::string &base_dir) { + return FilesystemVirtualDir::create(base_dir); +} + } // namespace io } // namespace taichi diff --git a/taichi/common/virtual_dir.h b/taichi/common/virtual_dir.h index 9772d29eaecfbd..66fe9b6ad4440a 100644 --- a/taichi/common/virtual_dir.h +++ b/taichi/common/virtual_dir.h @@ -12,6 +12,8 @@ struct TI_DLL_EXPORT VirtualDir { // Open a virtual directory based on what `path` points to. Zip files and // filesystem directories are supported. static std::unique_ptr open(const std::string &path); + static std::unique_ptr from_zip(const void *data, size_t size); + static std::unique_ptr from_fs_dir(const std::string &base_dir); // Get the `size` of the file at `path` in the virtual directory. Returns // false when the file doesn't exist. diff --git a/taichi/runtime/gfx/aot_module_loader_impl.cpp b/taichi/runtime/gfx/aot_module_loader_impl.cpp index 3ce2c45fdcf45e..16536536d0a8d3 100644 --- a/taichi/runtime/gfx/aot_module_loader_impl.cpp +++ b/taichi/runtime/gfx/aot_module_loader_impl.cpp @@ -5,7 +5,6 @@ #include "taichi/runtime/gfx/runtime.h" #include "taichi/aot/graph_data.h" -#include "taichi/common/virtual_dir.h" namespace taichi::lang { namespace gfx { @@ -27,9 +26,10 @@ class AotModuleImpl : public aot::Module { : module_path_(params.module_path), runtime_(params.runtime), device_api_backend_(device_api_backend) { - auto dir = io::VirtualDir::open(params.module_path); - TI_ERROR_IF(dir == nullptr, "cannot open aot module '{}'", - params.module_path); + std::unique_ptr dir_alt = + io::VirtualDir::from_fs_dir(module_path_); + const io::VirtualDir *dir = + params.dir == nullptr ? dir_alt.get() : params.dir; bool succ = true; diff --git a/taichi/runtime/gfx/aot_module_loader_impl.h b/taichi/runtime/gfx/aot_module_loader_impl.h index 2557d07a599137..ee74d2fdd6b243 100644 --- a/taichi/runtime/gfx/aot_module_loader_impl.h +++ b/taichi/runtime/gfx/aot_module_loader_impl.h @@ -11,20 +11,25 @@ #include "taichi/codegen/spirv/kernel_utils.h" #include "taichi/aot/module_builder.h" #include "taichi/aot/module_loader.h" +#include "taichi/common/virtual_dir.h" namespace taichi::lang { namespace gfx { struct TI_DLL_EXPORT AotModuleParams { - std::string module_path; + std::string module_path{}; + const io::VirtualDir *dir{nullptr}; GfxRuntime *runtime{nullptr}; bool enable_lazy_loading{false}; AotModuleParams() = default; - AotModuleParams(const std::string &path, GfxRuntime *rt) + [[deprecated]] AotModuleParams(const std::string &path, GfxRuntime *rt) : module_path(path), runtime(rt) { } + AotModuleParams(const io::VirtualDir *dir, GfxRuntime *rt) + : dir(dir), runtime(rt) { + } }; TI_DLL_EXPORT std::unique_ptr make_aot_module( diff --git a/tests/test_config.json b/tests/test_config.json index e819a0ec87470b..0064be40d6d1a1 100644 --- a/tests/test_config.json +++ b/tests/test_config.json @@ -198,6 +198,10 @@ "CapiTest.TestLoadTcmAotModule": [ ["cpp", "aot", "python_scripts", "tcm_test.py"], "--arch=vulkan" + ], + "CapiTest.TestCreateTcmAotModule": [ + ["cpp", "aot", "python_scripts", "tcm_test.py"], + "--arch=vulkan" ] } }