Skip to content

Commit

Permalink
[llvm] [aot] Add LLVM to CAPI part 6: Handle Field initialization in …
Browse files Browse the repository at this point in the history
…C-API (#5444)

* [llvm] [aot] Add LLVM to CAPI part 6: Handle Field initialization in C-API

* Renamed get_field() to get_snode_tree()
  • Loading branch information
jim19930609 authored Jul 19, 2022
1 parent 4bc6f0c commit b38a577
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 17 deletions.
14 changes: 14 additions & 0 deletions c_api/src/taichi_llvm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "taichi/program/compile_config.h"
#include "taichi/runtime/llvm/llvm_runtime_executor.h"
#include "taichi/runtime/llvm/llvm_aot_module_loader.h"
#include "taichi/runtime/cpu/aot_module_loader_impl.h"

#ifdef TI_WITH_CUDA
Expand Down Expand Up @@ -89,6 +90,19 @@ TiAotModule LlvmRuntime::load_aot_module(const char *module_path) {
#endif
}

/* TODO(zhanlue): expose allocate/deallocate_snode_tree_type() to C-API
Let's initialize SNodeTrees automatically for now since SNodeTreeType isn't
ready yet.
*/
auto *llvm_aot_module =
dynamic_cast<taichi::lang::LlvmAotModule *>(aot_module.get());
TI_ASSERT(llvm_aot_module != nullptr);
for (size_t i = 0; i < llvm_aot_module->get_num_snode_trees(); i++) {
auto *snode_tree = aot_module->get_snode_tree(std::to_string(i));
taichi::lang::allocate_aot_snode_tree_type(aot_module.get(), snode_tree,
this->result_buffer);
}

// Insert LLVMRuntime to RuntimeContext
executor_->prepare_runtime_context(&this->runtime_context_);
return (TiAotModule)(new AotModule(*this, std::move(aot_module)));
Expand Down
2 changes: 1 addition & 1 deletion taichi/aot/module_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ KernelTemplate *Module::get_kernel_template(const std::string &name) {
return kt_ptr;
}

Field *Module::get_field(const std::string &name) {
Field *Module::get_snode_tree(const std::string &name) {
auto itr = loaded_fields_.find(name);
if (itr != loaded_fields_.end()) {
return itr->second.get();
Expand Down
2 changes: 1 addition & 1 deletion taichi/aot/module_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class TI_DLL_EXPORT Module {

Kernel *get_kernel(const std::string &name);
KernelTemplate *get_kernel_template(const std::string &name);
Field *get_field(const std::string &name);
Field *get_snode_tree(const std::string &name);

virtual std::unique_ptr<aot::CompiledGraph> get_graph(std::string name) {
TI_NOT_IMPLEMENTED;
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/llvm/aot_graph_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class FieldImpl : public aot::Field {
: field_(std::move(field)) {
}

LlvmOfflineCache::FieldCacheData get_field() const {
LlvmOfflineCache::FieldCacheData get_snode_tree_cache() const {
return field_;
}

Expand Down
8 changes: 4 additions & 4 deletions taichi/runtime/llvm/llvm_aot_module_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,17 @@ std::unique_ptr<aot::CompiledGraph> LlvmAotModule::get_graph(std::string name) {
return std::make_unique<aot::CompiledGraph>(std::move(graph));
}

void finalize_aot_field(aot::Module *aot_module,
aot::Field *aot_field,
uint64 *result_buffer) {
void allocate_aot_snode_tree_type(aot::Module *aot_module,
aot::Field *aot_field,
uint64 *result_buffer) {
auto *llvm_aot_module = dynamic_cast<LlvmAotModule *>(aot_module);
auto *aot_field_impl = dynamic_cast<llvm_aot::FieldImpl *>(aot_field);

TI_ASSERT(llvm_aot_module != nullptr);
TI_ASSERT(aot_field_impl != nullptr);

auto *runtime_executor = llvm_aot_module->get_runtime_executor();
const auto &field_cache = aot_field_impl->get_field();
const auto &field_cache = aot_field_impl->get_snode_tree_cache();

int snode_tree_id = field_cache.tree_id;
if (!llvm_aot_module->is_snode_tree_initialized(snode_tree_id)) {
Expand Down
19 changes: 16 additions & 3 deletions taichi/runtime/llvm/llvm_aot_module_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
namespace taichi {
namespace lang {

TI_DLL_EXPORT void finalize_aot_field(aot::Module *aot_module,
aot::Field *aot_field,
uint64 *result_buffer);
/* TODO(zhanlue) refactor this interface once SNodeTreeType is available
The "aot::Field" created by "make_new_field()" is a SNodeTree in essense.
Therefore we're actually initializing the entire SNodeTree.
*/
TI_DLL_EXPORT void allocate_aot_snode_tree_type(aot::Module *aot_module,
aot::Field *aot_field,
uint64 *result_buffer);

class LlvmAotModule : public aot::Module {
public:
Expand Down Expand Up @@ -38,6 +42,10 @@ class LlvmAotModule : public aot::Module {
return executor_;
}

size_t get_num_snode_trees() {
return cache_reader_->get_num_snode_trees();
}

void set_initialized_snode_tree(int snode_tree_id) {
initialized_snode_tree_ids.insert(snode_tree_id);
}
Expand All @@ -59,6 +67,11 @@ class LlvmAotModule : public aot::Module {
std::unique_ptr<aot::Kernel> make_new_kernel(
const std::string &name) override;

/* TODO(zhanlue): replace "make_new_field()" with "make_snode_tree()" once
SNodeTreeType is available Field is not a standalone data structure - it is
essentially part of a SNodeTree object. User should always operate on a
"SNodeTree" instead of a "Field".
*/
std::unique_ptr<aot::Field> make_new_field(const std::string &name) override;

LlvmRuntimeExecutor *const executor_{nullptr};
Expand Down
4 changes: 4 additions & 0 deletions taichi/runtime/llvm/llvm_offline_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ LlvmOfflineCacheFileReader::LlvmOfflineCacheFileReader(
: path_(path), data_(std::move(data)), format_(format) {
}

size_t LlvmOfflineCacheFileReader::get_num_snode_trees() {
return data_.fields.size();
}

bool LlvmOfflineCacheFileReader::get_field_cache(
LlvmOfflineCache::FieldCacheData &res,
int snode_tree_id) {
Expand Down
2 changes: 2 additions & 0 deletions taichi/runtime/llvm/llvm_offline_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ class LlvmOfflineCacheFileReader {
bool get_field_cache(LlvmOfflineCache::FieldCacheData &res,
int snode_tree_id);

size_t get_num_snode_trees();

static std::unique_ptr<LlvmOfflineCacheFileReader> make(
const std::string &path,
LlvmOfflineCache::Format format = LlvmOfflineCache::Format::LL);
Expand Down
9 changes: 3 additions & 6 deletions tests/cpp/aot/llvm/field_aot_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,9 @@ void run_field_tests(aot::Module *mod,
aot::Kernel *k_check_activate_pointer_fields =
mod->get_kernel("check_activate_pointer_fields");

// Initialize Fields
aot::Field *field_x = mod->get_field("0" /*snode_tree_id*/);
aot::Field *field_y = mod->get_field("0" /*snode_tree_id*/);

finalize_aot_field(mod, field_x, result_buffer);
finalize_aot_field(mod, field_y, result_buffer);
// Initialize SNodeTree
aot::Field *snode_tree_0 = mod->get_snode_tree("0" /*snode_tree_id*/);
allocate_aot_snode_tree_type(mod, snode_tree_0, result_buffer);

int base_value = 10;
/* -------- Test Case 1 ------ */
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/aot/vulkan/aot_save_load_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ TEST(AotSaveLoad, Vulkan) {
vulkan_runtime->synchronize();

// Retrieve data
auto x_field = vk_module->get_field("place");
auto x_field = vk_module->get_snode_tree("place");
EXPECT_NE(x_field, nullptr);
}

Expand Down

0 comments on commit b38a577

Please sign in to comment.