Skip to content

Commit

Permalink
[lang] Refactor allocation logic for SNodeTreeBufferManager
Browse files Browse the repository at this point in the history
  • Loading branch information
jim19930609 committed Apr 12, 2023
1 parent b9a9d95 commit b192299
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 61 deletions.
2 changes: 2 additions & 0 deletions taichi/runtime/llvm/llvm_runtime_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,8 @@ void LlvmRuntimeExecutor::initialize_llvm_runtime_snodes(
const int root_id = field_cache_data.root_id;

TI_TRACE("Allocating data structure of size {} bytes", root_size);
root_size =
std::max(root_size, (size_t)taichi_page_size); // minimal allocation size
std::size_t rounded_size = taichi::iroundup(root_size, taichi_page_size);

Ptr root_buffer = snode_tree_buffer_manager_->allocate(
Expand Down
65 changes: 6 additions & 59 deletions taichi/runtime/llvm/snode_tree_buffer_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,74 +9,21 @@ SNodeTreeBufferManager::SNodeTreeBufferManager(
TI_TRACE("SNode tree buffer manager created.");
}

void SNodeTreeBufferManager::merge_and_insert(Ptr ptr, std::size_t size) {
// merge with right block
if (ptr_map_[ptr + size]) {
std::size_t tmp = ptr_map_[ptr + size];
size_set_.erase(std::make_pair(tmp, ptr + size));
ptr_map_.erase(ptr + size);
size += tmp;
}
// merge with left block
auto map_it = ptr_map_.lower_bound(ptr);
if (map_it != ptr_map_.begin()) {
auto x = *--map_it;
if (x.first + x.second == ptr) {
size_set_.erase(std::make_pair(x.second, x.first));
ptr_map_.erase(x.first);
ptr = x.first;
size += x.second;
}
}
size_set_.insert(std::make_pair(size, ptr));
ptr_map_[ptr] = size;
}

Ptr SNodeTreeBufferManager::allocate(JITModule *runtime_jit,
void *runtime,
std::size_t size,
std::size_t alignment,
const int snode_tree_id,
uint64 *result_buffer) {
TI_TRACE("allocating memory for SNode Tree {}", snode_tree_id);
TI_ASSERT_INFO(snode_tree_id < kMaxNumSnodeTreesLlvm,
"LLVM backend supports up to {} snode trees",
kMaxNumSnodeTreesLlvm);
auto set_it = size_set_.lower_bound(std::make_pair(size, nullptr));
if (set_it == size_set_.end()) {
runtime_jit->call<void *, std::size_t, std::size_t>(
"runtime_memory_allocate_aligned", runtime, size, alignment,
result_buffer);
auto ptr = runtime_exec_->fetch_result<Ptr>(0, result_buffer);
roots_[snode_tree_id] = ptr;
sizes_[snode_tree_id] = size;
return ptr;
} else {
auto x = *set_it;
size_set_.erase(x);
ptr_map_.erase(x.second);
if (x.first - size > 0) {
size_set_.insert(std::make_pair(x.first - size, x.second + size));
ptr_map_[x.second + size] = x.first - size;
}
TI_ASSERT(x.second);
roots_[snode_tree_id] = x.second;
sizes_[snode_tree_id] = size;
return x.second;
}
auto devalloc = runtime_exec_->allocate_memory_ndarray(size, result_buffer);
snode_tree_id_to_device_alloc_[snode_tree_id] = devalloc;
return (Ptr)runtime_exec_->get_ndarray_alloc_info_ptr(devalloc);
}

void SNodeTreeBufferManager::destroy(SNodeTree *snode_tree) {
int snode_tree_id = snode_tree->id();
TI_TRACE("Destroying SNode tree {}.", snode_tree_id);
std::size_t size = sizes_[snode_tree_id];
if (size == 0) {
TI_DEBUG("SNode tree {} destroy failed.", snode_tree_id);
return;
}
Ptr ptr = roots_[snode_tree_id];
merge_and_insert(ptr, size);
TI_DEBUG("SNode tree {} destroyed.", snode_tree_id);
auto devalloc = snode_tree_id_to_device_alloc_[snode_tree->id()];
runtime_exec_->deallocate_memory_ndarray(devalloc);
snode_tree_id_to_device_alloc_.erase(snode_tree->id());
}

} // namespace taichi::lang
5 changes: 3 additions & 2 deletions taichi/runtime/llvm/snode_tree_buffer_manager.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once
#include "taichi/inc/constants.h"
#include "taichi/struct/snode_tree.h"
#include "taichi/rhi/public_device.h"
#define TI_RUNTIME_HOST

#include <set>
Expand All @@ -16,8 +17,6 @@ class SNodeTreeBufferManager {
public:
explicit SNodeTreeBufferManager(LlvmRuntimeExecutor *runtime_exec);

void merge_and_insert(Ptr ptr, std::size_t size);

Ptr allocate(JITModule *runtime_jit,
void *runtime,
std::size_t size,
Expand All @@ -33,6 +32,8 @@ class SNodeTreeBufferManager {
LlvmRuntimeExecutor *runtime_exec_;
Ptr roots_[kMaxNumSnodeTreesLlvm];
std::size_t sizes_[kMaxNumSnodeTreesLlvm];

std::map<int, DeviceAllocation> snode_tree_id_to_device_alloc_;
};

} // namespace taichi::lang

0 comments on commit b192299

Please sign in to comment.