Skip to content

Commit

Permalink
[aot] Load AOT module from memory (#6692) (#6714)
Browse files Browse the repository at this point in the history
  • Loading branch information
PENGUINLIONG authored Nov 28, 2022
1 parent 37efd6d commit 489205c
Show file tree
Hide file tree
Showing 13 changed files with 172 additions and 18 deletions.
8 changes: 8 additions & 0 deletions c_api/include/taichi/cpp/taichi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,14 @@ class Runtime {
return load_aot_module(path.c_str());
}

AotModule create_aot_module(const void *tcm, size_t size) {
TiAotModule aot_module = ti_create_aot_module(runtime_, tcm, size);
return AotModule(runtime_, aot_module, true);
}
AotModule create_aot_module(const std::vector<uint8_t> &tcm) {
return create_aot_module(tcm.data(), tcm.size());
}

void copy_memory_device_to_device(const TiMemorySlice &dst_memory,
const TiMemorySlice &src_memory) {
ti_copy_memory_device_to_device(runtime_, &dst_memory, &src_memory);
Expand Down
5 changes: 5 additions & 0 deletions c_api/include/taichi/taichi_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,11 @@ TI_DLL_EXPORT void TI_API_CALL ti_wait(TiRuntime runtime);
TI_DLL_EXPORT TiAotModule TI_API_CALL
ti_load_aot_module(TiRuntime runtime, const char *module_path);

// Function `ti_create_aot_module`
TI_DLL_EXPORT TiAotModule TI_API_CALL ti_create_aot_module(TiRuntime runtime,
const void *tcm,
uint64_t size);

// Function `ti_destroy_aot_module`
//
// Destroys a loaded AOT module and releases all related resources.
Expand Down
25 changes: 24 additions & 1 deletion c_api/src/taichi_core_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "taichi_llvm_impl.h"
#include "taichi/program/ndarray.h"
#include "taichi/program/texture.h"
#include "taichi/common/virtual_dir.h"

struct ErrorCache {
TiError error{TI_ERROR_SUCCESS};
Expand All @@ -21,7 +22,7 @@ const char *describe_error(TiError error) {
case TI_ERROR_NOT_SUPPORTED:
return "not supported";
case TI_ERROR_CORRUPTED_DATA:
return "path not found";
return "corrupted data";
case TI_ERROR_NAME_NOT_FOUND:
return "name not found";
case TI_ERROR_INVALID_ARGUMENT:
Expand Down Expand Up @@ -492,6 +493,8 @@ TiAotModule ti_load_aot_module(TiRuntime runtime, const char *module_path) {
TI_CAPI_ARGUMENT_NULL_RV(runtime);
TI_CAPI_ARGUMENT_NULL_RV(module_path);

// (penguinliong) Should call `create_aot_module` directly after all backends
// adapted to it.
TiAotModule aot_module = ((Runtime *)runtime)->load_aot_module(module_path);

if (aot_module == TI_NULL_HANDLE) {
Expand All @@ -502,6 +505,26 @@ TiAotModule ti_load_aot_module(TiRuntime runtime, const char *module_path) {
TI_CAPI_TRY_CATCH_END();
return out;
}
TiAotModule ti_create_aot_module(TiRuntime runtime,
const void *tcm,
uint64_t size) {
TiAotModule out = TI_NULL_HANDLE;
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL_RV(runtime);
TI_CAPI_ARGUMENT_NULL_RV(tcm);

auto dir = taichi::io::VirtualDir::from_zip(tcm, size);
if (dir == TI_NULL_HANDLE) {
ti_set_last_error(TI_ERROR_CORRUPTED_DATA, "tcm");
return TI_NULL_HANDLE;
}

Error err = ((Runtime *)runtime)->create_aot_module(dir.get(), out);
err.set_last_error();

TI_CAPI_TRY_CATCH_END();
return out;
}
void ti_destroy_aot_module(TiAotModule aot_module) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(aot_module);
Expand Down
36 changes: 35 additions & 1 deletion c_api/src/taichi_core_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "taichi/rhi/device.h"
#include "taichi/aot/graph_data.h"
#include "taichi/aot/module_loader.h"
#include "taichi/common/virtual_dir.h"

#define TI_CAPI_NOT_SUPPORTED(x) ti_set_last_error(TI_ERROR_NOT_SUPPORTED, #x);
#define TI_CAPI_NOT_SUPPORTED_IF(x) \
Expand Down Expand Up @@ -89,6 +90,28 @@
ti_set_last_error(TI_ERROR_INVALID_STATE, "c++ exception"); \
}

struct Error {
TiError error;
std::string message;

Error(TiError error, const std::string &message)
: error(error), message(message) {
}
Error() : error(TI_ERROR_SUCCESS), message() {
}
Error(const Error &) = delete;
Error(Error &&) = default;
Error &operator=(const Error &) = delete;
Error &operator=(Error &&) = default;

// Set this error as the last error if it's not `TI_ERROR_SUCCESS`.
inline void set_last_error() const {
if (error != TI_ERROR_SUCCESS) {
ti_set_last_error(error, message.c_str());
}
}
};

class Runtime {
protected:
// 32 is a magic number in `taichi/inc/constants.h`.
Expand All @@ -104,7 +127,18 @@ class Runtime {

virtual taichi::lang::Device &get() = 0;

virtual TiAotModule load_aot_module(const char *module_path) = 0;
[[deprecated("create_aot_module")]] virtual TiAotModule load_aot_module(
const char *module_path) {
auto dir = taichi::io::VirtualDir::open(module_path);
TiAotModule aot_module = TI_NULL_HANDLE;
Error err = create_aot_module(dir.get(), aot_module);
err.set_last_error();
return aot_module;
}
virtual Error create_aot_module(const taichi::io::VirtualDir *dir,
TiAotModule &out) {
TI_NOT_IMPLEMENTED
}
virtual TiMemory allocate_memory(
const taichi::lang::Device::AllocParams &params);
virtual void free_memory(TiMemory devmem);
Expand Down
16 changes: 9 additions & 7 deletions c_api/src/taichi_gfx_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
GfxRuntime::GfxRuntime(taichi::Arch arch) : Runtime(arch) {
}

TiAotModule GfxRuntime::load_aot_module(const char *module_path) {
Error GfxRuntime::create_aot_module(const taichi::io::VirtualDir *dir,
TiAotModule &out) {
taichi::lang::gfx::AotModuleParams params{};
params.module_path = module_path;
params.dir = dir;
params.runtime = &get_gfx_runtime();
std::unique_ptr<taichi::lang::aot::Module> aot_module =
taichi::lang::aot::Module::load(arch, params);
if (aot_module->is_corrupted()) {
return TI_NULL_HANDLE;
return Error(TI_ERROR_CORRUPTED_DATA, "aot_module");
}

const taichi::lang::DeviceCapabilityConfig &current_devcaps =
Expand All @@ -21,15 +22,16 @@ TiAotModule GfxRuntime::load_aot_module(const char *module_path) {
for (const auto &pair : required_devcaps.devcaps) {
uint32_t current_version = current_devcaps.get(pair.first);
uint32_t required_version = pair.second;
if (current_version == required_version) {
ti_set_last_error(TI_ERROR_INCOMPATIBLE_MODULE,
taichi::lang::to_string(pair.first).c_str());
if (current_version != required_version) {
return Error(TI_ERROR_INCOMPATIBLE_MODULE,
taichi::lang::to_string(pair.first).c_str());
}
}

size_t root_size = aot_module->get_root_size();
params.runtime->add_root_buffer(root_size);
return (TiAotModule)(new AotModule(*this, std::move(aot_module)));
out = (TiAotModule) new AotModule(*this, std::move(aot_module));
return Error();
}
void GfxRuntime::buffer_copy(const taichi::lang::DevicePtr &dst,
const taichi::lang::DevicePtr &src,
Expand Down
4 changes: 3 additions & 1 deletion c_api/src/taichi_gfx_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "taichi_core_impl.h"
#include "taichi/runtime/gfx/runtime.h"
#include "taichi/common/virtual_dir.h"

class GfxRuntime;

Expand All @@ -10,7 +11,8 @@ class GfxRuntime : public Runtime {
GfxRuntime(taichi::Arch arch);
virtual taichi::lang::gfx::GfxRuntime &get_gfx_runtime() = 0;

virtual TiAotModule load_aot_module(const char *module_path) override final;
virtual Error create_aot_module(const taichi::io::VirtualDir *dir,
TiAotModule &out) override final;
virtual void buffer_copy(const taichi::lang::DevicePtr &dst,
const taichi::lang::DevicePtr &src,
size_t size) override final;
Expand Down
21 changes: 21 additions & 0 deletions c_api/taichi.json
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,27 @@
}
]
},
{
"name": "create_aot_module",
"type": "function",
"parameters": [
{
"name": "@return",
"type": "handle.aot_module"
},
{
"type": "handle.runtime"
},
{
"name": "tcm",
"type": "const void*"
},
{
"name": "size",
"type": "uint64_t"
}
]
},
{
"name": "destroy_aot_module",
"type": "function",
Expand Down
37 changes: 37 additions & 0 deletions c_api/tests/c_api_interface_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,40 @@ TEST_F(CapiTest, TestLoadTcmAotModule) {
}
}
}

TEST_F(CapiTest, TestCreateTcmAotModule) {
if (capi::utils::is_vulkan_available()) {
const auto folder_dir = getenv("TAICHI_AOT_FOLDER_PATH");

std::stringstream aot_mod_ss;
aot_mod_ss << folder_dir << "/module.tcm";

std::vector<uint8_t> tcm;
{
std::fstream f(aot_mod_ss.str(),
std::ios::in | std::ios::binary | std::ios::ate);
TI_ASSERT(f.is_open());
tcm.resize(f.tellg());
f.seekg(std::ios::beg);
f.read((char *)tcm.data(), tcm.size());
}

{
// Vulkan Runtime
TiArch arch = TiArch::TI_ARCH_VULKAN;
ti::Runtime runtime(arch);
ti::AotModule aot_mod = runtime.create_aot_module(tcm);
ti::Kernel run = aot_mod.get_kernel("run");
ti::NdArray<int32_t> arr =
runtime.allocate_ndarray<int32_t>({16}, {}, true);
run[0] = arr;
run.launch();
runtime.wait();
std::vector<int32_t> data(16);
arr.read(data);
for (int32_t i = 0; i < 16; ++i) {
TI_ASSERT(data.at(i) == i);
}
}
}
}
15 changes: 13 additions & 2 deletions taichi/common/virtual_dir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,11 @@ struct ZipArchiveVirtualDir : public VirtualDir {
f.seekg(std::ios::beg);
f.read((char *)archive_data.data(), archive_data.size());

return from_zip(archive_data.data(), archive_data.size());
}
static std::unique_ptr<VirtualDir> from_zip(const void *data, size_t size) {
zip::ZipArchive archive;
bool succ = zip::ZipArchive::try_from_bytes(archive_data.data(),
archive_data.size(), archive);
bool succ = zip::ZipArchive::try_from_bytes(data, size, archive);
if (!succ) {
return nullptr;
}
Expand Down Expand Up @@ -121,5 +123,14 @@ std::unique_ptr<VirtualDir> VirtualDir::open(const std::string &path) {
}
}

std::unique_ptr<VirtualDir> VirtualDir::from_zip(const void *data,
size_t size) {
return ZipArchiveVirtualDir::from_zip(data, size);
}
std::unique_ptr<VirtualDir> VirtualDir::from_fs_dir(
const std::string &base_dir) {
return FilesystemVirtualDir::create(base_dir);
}

} // namespace io
} // namespace taichi
2 changes: 2 additions & 0 deletions taichi/common/virtual_dir.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct TI_DLL_EXPORT VirtualDir {
// Open a virtual directory based on what `path` points to. Zip files and
// filesystem directories are supported.
static std::unique_ptr<VirtualDir> open(const std::string &path);
static std::unique_ptr<VirtualDir> from_zip(const void *data, size_t size);
static std::unique_ptr<VirtualDir> from_fs_dir(const std::string &base_dir);

// Get the `size` of the file at `path` in the virtual directory. Returns
// false when the file doesn't exist.
Expand Down
8 changes: 4 additions & 4 deletions taichi/runtime/gfx/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

#include "taichi/runtime/gfx/runtime.h"
#include "taichi/aot/graph_data.h"
#include "taichi/common/virtual_dir.h"

namespace taichi::lang {
namespace gfx {
Expand All @@ -27,9 +26,10 @@ class AotModuleImpl : public aot::Module {
: module_path_(params.module_path),
runtime_(params.runtime),
device_api_backend_(device_api_backend) {
auto dir = io::VirtualDir::open(params.module_path);
TI_ERROR_IF(dir == nullptr, "cannot open aot module '{}'",
params.module_path);
std::unique_ptr<io::VirtualDir> dir_alt =
io::VirtualDir::from_fs_dir(module_path_);
const io::VirtualDir *dir =
params.dir == nullptr ? dir_alt.get() : params.dir;

bool succ = true;

Expand Down
9 changes: 7 additions & 2 deletions taichi/runtime/gfx/aot_module_loader_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,25 @@
#include "taichi/codegen/spirv/kernel_utils.h"
#include "taichi/aot/module_builder.h"
#include "taichi/aot/module_loader.h"
#include "taichi/common/virtual_dir.h"

namespace taichi::lang {
namespace gfx {

struct TI_DLL_EXPORT AotModuleParams {
std::string module_path;
std::string module_path{};
const io::VirtualDir *dir{nullptr};
GfxRuntime *runtime{nullptr};
bool enable_lazy_loading{false};

AotModuleParams() = default;

AotModuleParams(const std::string &path, GfxRuntime *rt)
[[deprecated]] AotModuleParams(const std::string &path, GfxRuntime *rt)
: module_path(path), runtime(rt) {
}
AotModuleParams(const io::VirtualDir *dir, GfxRuntime *rt)
: dir(dir), runtime(rt) {
}
};

TI_DLL_EXPORT std::unique_ptr<aot::Module> make_aot_module(
Expand Down
4 changes: 4 additions & 0 deletions tests/test_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@
"CapiTest.TestLoadTcmAotModule": [
["cpp", "aot", "python_scripts", "tcm_test.py"],
"--arch=vulkan"
],
"CapiTest.TestCreateTcmAotModule": [
["cpp", "aot", "python_scripts", "tcm_test.py"],
"--arch=vulkan"
]
}
}

0 comments on commit 489205c

Please sign in to comment.