Skip to content

Commit

Permalink
[aot] [refactor] Refactor AOT field API for Vulkan (#4490)
Browse files Browse the repository at this point in the history
* Refactor field APIs for vulkan backend

* Fix

* Auto Format

* Make root buffer readable from host

* Auto Format

* Update taichi/backends/vulkan/aot_module_loader_impl.cpp

Co-authored-by: Ye Kuang <[email protected]>

* With constexpr

* Make functions simpler

* Remove copy_to_host_buffer API

* Update tests/cpp/aot/aot_save_load_test.cpp

Co-authored-by: Ye Kuang <[email protected]>

Co-authored-by: Taichi Gardener <[email protected]>
Co-authored-by: Ye Kuang <[email protected]>
  • Loading branch information
3 people authored Mar 10, 2022
1 parent d196c8d commit 570c39d
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 12 deletions.
3 changes: 0 additions & 3 deletions taichi/aot/module_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,6 @@ class TI_DLL_EXPORT Module {

protected:
virtual std::unique_ptr<Kernel> make_new_kernel(const std::string &name) = 0;

private:
std::unordered_map<std::string, std::unique_ptr<Kernel>> loaded_kernels_;
};

// Only responsible for reporting device capabilities
Expand Down
11 changes: 11 additions & 0 deletions taichi/backends/metal/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@ namespace lang {
namespace metal {
namespace {

class FieldImpl : public aot::Field {
public:
explicit FieldImpl(KernelManager *runtime, const CompiledFieldData &field)
: runtime_(runtime), field_(field) {
}

private:
KernelManager *const runtime_;
CompiledFieldData field_;
};

class KernelImpl : public aot::Kernel {
public:
explicit KernelImpl(KernelManager *runtime, const std::string &kernel_name)
Expand Down
29 changes: 28 additions & 1 deletion taichi/backends/vulkan/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@ namespace {

using KernelHandle = VkRuntime::KernelHandle;

class FieldImpl : public aot::Field {
public:
explicit FieldImpl(VkRuntime *runtime, const aot::CompiledFieldData &field)
: runtime_(runtime), field_(field) {
}

private:
VkRuntime *const runtime_;
aot::CompiledFieldData field_;
};

class KernelImpl : public aot::Kernel {
public:
explicit KernelImpl(VkRuntime *runtime, KernelHandle handle)
Expand Down Expand Up @@ -53,7 +64,12 @@ class AotModuleImpl : public aot::Module {
}

std::unique_ptr<aot::Field> get_field(const std::string &name) override {
TI_NOT_IMPLEMENTED;
aot::CompiledFieldData field;
if (!get_field_data_by_name(name, field)) {
TI_DEBUG("Failed to load field {}", name);
return nullptr;
}
return std::make_unique<FieldImpl>(runtime_, field);
}

size_t get_root_size() const override {
Expand All @@ -69,6 +85,17 @@ class AotModuleImpl : public aot::Module {
}

private:
bool get_field_data_by_name(const std::string &name,
aot::CompiledFieldData &field) {
for (int i = 0; i < ti_aot_data_.fields.size(); ++i) {
if (ti_aot_data_.fields[i].field_name.rfind(name, 0) == 0) {
field = ti_aot_data_.fields[i];
return true;
}
}
return false;
}

bool get_kernel_params_by_name(const std::string &name,
VkRuntime::RegisterParams &kernel) {
for (int i = 0; i < ti_aot_data_.kernels.size(); ++i) {
Expand Down
1 change: 1 addition & 0 deletions taichi/backends/vulkan/aot_module_loader_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct AotModuleParams {
};

std::unique_ptr<aot::Module> make_aot_module(std::any mod_params);

} // namespace vulkan
} // namespace lang
} // namespace taichi
25 changes: 21 additions & 4 deletions taichi/backends/vulkan/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,15 +591,32 @@ void VkRuntime::add_root_buffer(size_t root_buffer_size) {
root_buffer_size = 4; // there might be empty roots
}
std::unique_ptr<DeviceAllocationGuard> new_buffer =
device_->allocate_memory_unique(
{root_buffer_size,
/*host_write=*/false, /*host_read=*/false,
/*export_sharing=*/false, AllocUsage::Storage});
device_->allocate_memory_unique({root_buffer_size,
/*host_write=*/true, /*host_read=*/true,
/*export_sharing=*/false,
AllocUsage::Storage});
Stream *stream = device_->get_compute_stream();
auto cmdlist = stream->new_command_list();
cmdlist->buffer_fill(new_buffer->get_ptr(0), root_buffer_size, /*data=*/0);
stream->submit_synced(cmdlist.get());
root_buffers_.push_back(std::move(new_buffer));
// cache the root buffer size
root_buffers_size_map_[root_buffers_.back().get()] = root_buffer_size;
}

DeviceAllocation *VkRuntime::get_root_buffer(int id) const {
if (id >= root_buffers_.size()) {
TI_ERROR("root buffer id {} not found", id);
}
return root_buffers_[id].get();
}

size_t VkRuntime::get_root_buffer_size(int id) const {
auto it = root_buffers_size_map_.find(root_buffers_[id].get());
if (id >= root_buffers_.size() || it == root_buffers_size_map_.end()) {
TI_ERROR("root buffer id {} not found", id);
}
return it->second;
}

VkRuntime::RegisterParams run_codegen(
Expand Down
6 changes: 6 additions & 0 deletions taichi/backends/vulkan/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ class TI_DLL_EXPORT VkRuntime {

void add_root_buffer(size_t root_buffer_size);

DeviceAllocation *get_root_buffer(int id) const;

size_t get_root_buffer_size(int id) const;

private:
friend class taichi::lang::vulkan::SNodeTreeManager;

Expand All @@ -125,6 +129,8 @@ class TI_DLL_EXPORT VkRuntime {
high_res_clock::time_point current_cmdlist_pending_since_;

std::vector<std::unique_ptr<CompiledTaichiKernel>> ti_kernels_;

std::unordered_map<DeviceAllocation *, size_t> root_buffers_size_map_;
};

VkRuntime::RegisterParams run_codegen(
Expand Down
8 changes: 4 additions & 4 deletions tests/cpp/aot/aot_save_load_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ TEST(AotSaveLoad, Vulkan) {
aot::Module::load(".", Arch::vulkan, mod_params);
EXPECT_TRUE(vk_module);

// Retrieve kernels/fields/etc from AOT module to initialize runtime
// Retrieve kernels/fields/etc from AOT module
auto root_size = vk_module->get_root_size();
EXPECT_EQ(root_size, 64);
vulkan_runtime->add_root_buffer(root_size);
Expand All @@ -146,8 +146,8 @@ TEST(AotSaveLoad, Vulkan) {
ret_kernel->launch(&host_ctx);
vulkan_runtime->synchronize();

// auto x_field = vk_module.get_field("x");
// EXPECT_TRUE(x_field);
// x_field.copy_to(/*dst=*/x.get());
// Retrieve data
auto x_field = vk_module->get_field("place");
EXPECT_NE(x_field, nullptr);
}
#endif

0 comments on commit 570c39d

Please sign in to comment.