Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[aot] Load AOT module from memory (#6692) #6714

Merged
merged 1 commit into from
Nov 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions c_api/include/taichi/cpp/taichi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,14 @@ class Runtime {
return load_aot_module(path.c_str());
}

AotModule create_aot_module(const void *tcm, size_t size) {
TiAotModule aot_module = ti_create_aot_module(runtime_, tcm, size);
return AotModule(runtime_, aot_module, true);
}
AotModule create_aot_module(const std::vector<uint8_t> &tcm) {
return create_aot_module(tcm.data(), tcm.size());
}

void copy_memory_device_to_device(const TiMemorySlice &dst_memory,
const TiMemorySlice &src_memory) {
ti_copy_memory_device_to_device(runtime_, &dst_memory, &src_memory);
Expand Down
5 changes: 5 additions & 0 deletions c_api/include/taichi/taichi_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,11 @@ TI_DLL_EXPORT void TI_API_CALL ti_wait(TiRuntime runtime);
TI_DLL_EXPORT TiAotModule TI_API_CALL
ti_load_aot_module(TiRuntime runtime, const char *module_path);

// Function `ti_create_aot_module`
TI_DLL_EXPORT TiAotModule TI_API_CALL ti_create_aot_module(TiRuntime runtime,
const void *tcm,
uint64_t size);

// Function `ti_destroy_aot_module`
//
// Destroys a loaded AOT module and releases all related resources.
Expand Down
25 changes: 24 additions & 1 deletion c_api/src/taichi_core_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "taichi_llvm_impl.h"
#include "taichi/program/ndarray.h"
#include "taichi/program/texture.h"
#include "taichi/common/virtual_dir.h"

struct ErrorCache {
TiError error{TI_ERROR_SUCCESS};
Expand All @@ -21,7 +22,7 @@ const char *describe_error(TiError error) {
case TI_ERROR_NOT_SUPPORTED:
return "not supported";
case TI_ERROR_CORRUPTED_DATA:
return "path not found";
return "corrupted data";
case TI_ERROR_NAME_NOT_FOUND:
return "name not found";
case TI_ERROR_INVALID_ARGUMENT:
Expand Down Expand Up @@ -492,6 +493,8 @@ TiAotModule ti_load_aot_module(TiRuntime runtime, const char *module_path) {
TI_CAPI_ARGUMENT_NULL_RV(runtime);
TI_CAPI_ARGUMENT_NULL_RV(module_path);

// (penguinliong) Should call `create_aot_module` directly after all backends
// adapted to it.
TiAotModule aot_module = ((Runtime *)runtime)->load_aot_module(module_path);

if (aot_module == TI_NULL_HANDLE) {
Expand All @@ -502,6 +505,26 @@ TiAotModule ti_load_aot_module(TiRuntime runtime, const char *module_path) {
TI_CAPI_TRY_CATCH_END();
return out;
}
TiAotModule ti_create_aot_module(TiRuntime runtime,
const void *tcm,
uint64_t size) {
TiAotModule out = TI_NULL_HANDLE;
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL_RV(runtime);
TI_CAPI_ARGUMENT_NULL_RV(tcm);

auto dir = taichi::io::VirtualDir::from_zip(tcm, size);
if (dir == TI_NULL_HANDLE) {
ti_set_last_error(TI_ERROR_CORRUPTED_DATA, "tcm");
return TI_NULL_HANDLE;
}

Error err = ((Runtime *)runtime)->create_aot_module(dir.get(), out);
err.set_last_error();

TI_CAPI_TRY_CATCH_END();
return out;
}
void ti_destroy_aot_module(TiAotModule aot_module) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(aot_module);
Expand Down
36 changes: 35 additions & 1 deletion c_api/src/taichi_core_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "taichi/rhi/device.h"
#include "taichi/aot/graph_data.h"
#include "taichi/aot/module_loader.h"
#include "taichi/common/virtual_dir.h"

#define TI_CAPI_NOT_SUPPORTED(x) ti_set_last_error(TI_ERROR_NOT_SUPPORTED, #x);
#define TI_CAPI_NOT_SUPPORTED_IF(x) \
Expand Down Expand Up @@ -89,6 +90,28 @@
ti_set_last_error(TI_ERROR_INVALID_STATE, "c++ exception"); \
}

struct Error {
TiError error;
std::string message;

Error(TiError error, const std::string &message)
: error(error), message(message) {
}
Error() : error(TI_ERROR_SUCCESS), message() {
}
Error(const Error &) = delete;
Error(Error &&) = default;
Error &operator=(const Error &) = delete;
Error &operator=(Error &&) = default;

// Set this error as the last error if it's not `TI_ERROR_SUCCESS`.
inline void set_last_error() const {
if (error != TI_ERROR_SUCCESS) {
ti_set_last_error(error, message.c_str());
}
}
};

class Runtime {
protected:
// 32 is a magic number in `taichi/inc/constants.h`.
Expand All @@ -104,7 +127,18 @@ class Runtime {

virtual taichi::lang::Device &get() = 0;

virtual TiAotModule load_aot_module(const char *module_path) = 0;
[[deprecated("create_aot_module")]] virtual TiAotModule load_aot_module(
const char *module_path) {
auto dir = taichi::io::VirtualDir::open(module_path);
TiAotModule aot_module = TI_NULL_HANDLE;
Error err = create_aot_module(dir.get(), aot_module);
err.set_last_error();
return aot_module;
}
virtual Error create_aot_module(const taichi::io::VirtualDir *dir,
TiAotModule &out) {
TI_NOT_IMPLEMENTED
}
virtual TiMemory allocate_memory(
const taichi::lang::Device::AllocParams &params);
virtual void free_memory(TiMemory devmem);
Expand Down
16 changes: 9 additions & 7 deletions c_api/src/taichi_gfx_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
GfxRuntime::GfxRuntime(taichi::Arch arch) : Runtime(arch) {
}

TiAotModule GfxRuntime::load_aot_module(const char *module_path) {
Error GfxRuntime::create_aot_module(const taichi::io::VirtualDir *dir,
TiAotModule &out) {
taichi::lang::gfx::AotModuleParams params{};
params.module_path = module_path;
params.dir = dir;
params.runtime = &get_gfx_runtime();
std::unique_ptr<taichi::lang::aot::Module> aot_module =
taichi::lang::aot::Module::load(arch, params);
if (aot_module->is_corrupted()) {
return TI_NULL_HANDLE;
return Error(TI_ERROR_CORRUPTED_DATA, "aot_module");
}

const taichi::lang::DeviceCapabilityConfig &current_devcaps =
Expand All @@ -21,15 +22,16 @@ TiAotModule GfxRuntime::load_aot_module(const char *module_path) {
for (const auto &pair : required_devcaps.devcaps) {
uint32_t current_version = current_devcaps.get(pair.first);
uint32_t required_version = pair.second;
if (current_version == required_version) {
ti_set_last_error(TI_ERROR_INCOMPATIBLE_MODULE,
taichi::lang::to_string(pair.first).c_str());
if (current_version != required_version) {
return Error(TI_ERROR_INCOMPATIBLE_MODULE,
taichi::lang::to_string(pair.first).c_str());
}
}

size_t root_size = aot_module->get_root_size();
params.runtime->add_root_buffer(root_size);
return (TiAotModule)(new AotModule(*this, std::move(aot_module)));
out = (TiAotModule) new AotModule(*this, std::move(aot_module));
return Error();
}
void GfxRuntime::buffer_copy(const taichi::lang::DevicePtr &dst,
const taichi::lang::DevicePtr &src,
Expand Down
4 changes: 3 additions & 1 deletion c_api/src/taichi_gfx_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "taichi_core_impl.h"
#include "taichi/runtime/gfx/runtime.h"
#include "taichi/common/virtual_dir.h"

class GfxRuntime;

Expand All @@ -10,7 +11,8 @@ class GfxRuntime : public Runtime {
GfxRuntime(taichi::Arch arch);
virtual taichi::lang::gfx::GfxRuntime &get_gfx_runtime() = 0;

virtual TiAotModule load_aot_module(const char *module_path) override final;
virtual Error create_aot_module(const taichi::io::VirtualDir *dir,
TiAotModule &out) override final;
virtual void buffer_copy(const taichi::lang::DevicePtr &dst,
const taichi::lang::DevicePtr &src,
size_t size) override final;
Expand Down
21 changes: 21 additions & 0 deletions c_api/taichi.json
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,27 @@
}
]
},
{
"name": "create_aot_module",
"type": "function",
"parameters": [
{
"name": "@return",
"type": "handle.aot_module"
},
{
"type": "handle.runtime"
},
{
"name": "tcm",
"type": "const void*"
},
{
"name": "size",
"type": "uint64_t"
}
]
},
{
"name": "destroy_aot_module",
"type": "function",
Expand Down
37 changes: 37 additions & 0 deletions c_api/tests/c_api_interface_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,40 @@ TEST_F(CapiTest, TestLoadTcmAotModule) {
}
}
}

TEST_F(CapiTest, TestCreateTcmAotModule) {
if (capi::utils::is_vulkan_available()) {
const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");

std::stringstream aot_mod_ss;
aot_mod_ss << folder_dir << "/module.tcm";

std::vector<uint8_t> tcm;
{
std::fstream f(aot_mod_ss.str(),
std::ios::in | std::ios::binary | std::ios::ate);
TI_ASSERT(f.is_open());
tcm.resize(f.tellg());
f.seekg(std::ios::beg);
f.read((char *)tcm.data(), tcm.size());
}

{
// Vulkan Runtime
TiArch arch = TiArch::TI_ARCH_VULKAN;
ti::Runtime runtime(arch);
ti::AotModule aot_mod = runtime.create_aot_module(tcm);
ti::Kernel run = aot_mod.get_kernel("run");
ti::NdArray<int32_t> arr =
runtime.allocate_ndarray<int32_t>({16}, {}, true);
run[0] = arr;
run.launch();
runtime.wait();
std::vector<int32_t> data(16);
arr.read(data);
for (int32_t i = 0; i < 16; ++i) {
TI_ASSERT(data.at(i) == i);
}
}
}
}
15 changes: 13 additions & 2 deletions taichi/common/virtual_dir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,11 @@ struct ZipArchiveVirtualDir : public VirtualDir {
f.seekg(std::ios::beg);
f.read((char *)archive_data.data(), archive_data.size());

return from_zip(archive_data.data(), archive_data.size());
}
static std::unique_ptr<VirtualDir> from_zip(const void *data, size_t size) {
zip::ZipArchive archive;
bool succ = zip::ZipArchive::try_from_bytes(archive_data.data(),
archive_data.size(), archive);
bool succ = zip::ZipArchive::try_from_bytes(data, size, archive);
if (!succ) {
return nullptr;
}
Expand Down Expand Up @@ -121,5 +123,14 @@ std::unique_ptr<VirtualDir> VirtualDir::open(const std::string &path) {
}
}

std::unique_ptr<VirtualDir> VirtualDir::from_zip(const void *data,
size_t size) {
return ZipArchiveVirtualDir::from_zip(data, size);
}
std::unique_ptr<VirtualDir> VirtualDir::from_fs_dir(
const std::string &base_dir) {
return FilesystemVirtualDir::create(base_dir);
}

} // namespace io
} // namespace taichi
2 changes: 2 additions & 0 deletions taichi/common/virtual_dir.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct TI_DLL_EXPORT VirtualDir {
// Open a virtual directory based on what `path` points to. Zip files and
// filesystem directories are supported.
static std::unique_ptr<VirtualDir> open(const std::string &path);
static std::unique_ptr<VirtualDir> from_zip(const void *data, size_t size);
static std::unique_ptr<VirtualDir> from_fs_dir(const std::string &base_dir);

// Get the `size` of the file at `path` in the virtual directory. Returns
// false when the file doesn't exist.
Expand Down
8 changes: 4 additions & 4 deletions taichi/runtime/gfx/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

#include "taichi/runtime/gfx/runtime.h"
#include "taichi/aot/graph_data.h"
#include "taichi/common/virtual_dir.h"

namespace taichi::lang {
namespace gfx {
Expand All @@ -27,9 +26,10 @@ class AotModuleImpl : public aot::Module {
: module_path_(params.module_path),
runtime_(params.runtime),
device_api_backend_(device_api_backend) {
auto dir = io::VirtualDir::open(params.module_path);
TI_ERROR_IF(dir == nullptr, "cannot open aot module '{}'",
params.module_path);
std::unique_ptr<io::VirtualDir> dir_alt =
io::VirtualDir::from_fs_dir(module_path_);
const io::VirtualDir *dir =
params.dir == nullptr ? dir_alt.get() : params.dir;

bool succ = true;

Expand Down
9 changes: 7 additions & 2 deletions taichi/runtime/gfx/aot_module_loader_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,25 @@
#include "taichi/codegen/spirv/kernel_utils.h"
#include "taichi/aot/module_builder.h"
#include "taichi/aot/module_loader.h"
#include "taichi/common/virtual_dir.h"

namespace taichi::lang {
namespace gfx {

struct TI_DLL_EXPORT AotModuleParams {
std::string module_path;
std::string module_path{};
const io::VirtualDir *dir{nullptr};
GfxRuntime *runtime{nullptr};
bool enable_lazy_loading{false};

AotModuleParams() = default;

AotModuleParams(const std::string &path, GfxRuntime *rt)
[[deprecated]] AotModuleParams(const std::string &path, GfxRuntime *rt)
: module_path(path), runtime(rt) {
}
AotModuleParams(const io::VirtualDir *dir, GfxRuntime *rt)
: dir(dir), runtime(rt) {
}
};

TI_DLL_EXPORT std::unique_ptr<aot::Module> make_aot_module(
Expand Down
4 changes: 4 additions & 0 deletions tests/test_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@
"CapiTest.TestLoadTcmAotModule": [
["cpp", "aot", "python_scripts", "tcm_test.py"],
"--arch=vulkan"
],
"CapiTest.TestCreateTcmAotModule": [
["cpp", "aot", "python_scripts", "tcm_test.py"],
"--arch=vulkan"
]
}
}