Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lang] Refactor allocation logic for SNodeTreeBufferManager #7795

Merged
merged 3 commits into from
Apr 19, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[lang] Refactor allocation logic for SNodeTreeBufferManager
jim19930609 authored and taichi-gardener committed Apr 19, 2023
commit 2b9c6ec13e865778073f6b0bb22adf4891e14ce2
7 changes: 4 additions & 3 deletions taichi/runtime/llvm/llvm_runtime_executor.cpp
Original file line number Diff line number Diff line change
@@ -394,12 +394,13 @@ void LlvmRuntimeExecutor::initialize_llvm_runtime_snodes(
const int tree_id = field_cache_data.tree_id;
const int root_id = field_cache_data.root_id;

root_size =
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved
std::max(root_size, (size_t)taichi_page_size); // minimal allocation size
TI_TRACE("Allocating data structure of size {} bytes", root_size);
std::size_t rounded_size = taichi::iroundup(root_size, taichi_page_size);

Ptr root_buffer = snode_tree_buffer_manager_->allocate(
runtime_jit, llvm_runtime_, rounded_size, taichi_page_size, tree_id,
result_buffer);
Ptr root_buffer = snode_tree_buffer_manager_->allocate(rounded_size, tree_id,
result_buffer);
if (config_.arch == Arch::cuda) {
#if defined(TI_WITH_CUDA)
CUDADriver::get_instance().memset(root_buffer, 0, rounded_size);
70 changes: 7 additions & 63 deletions taichi/runtime/llvm/snode_tree_buffer_manager.cpp
Original file line number Diff line number Diff line change
@@ -9,74 +9,18 @@ SNodeTreeBufferManager::SNodeTreeBufferManager(
TI_TRACE("SNode tree buffer manager created.");
}

void SNodeTreeBufferManager::merge_and_insert(Ptr ptr, std::size_t size) {
// merge with right block
if (ptr_map_[ptr + size]) {
std::size_t tmp = ptr_map_[ptr + size];
size_set_.erase(std::make_pair(tmp, ptr + size));
ptr_map_.erase(ptr + size);
size += tmp;
}
// merge with left block
auto map_it = ptr_map_.lower_bound(ptr);
if (map_it != ptr_map_.begin()) {
auto x = *--map_it;
if (x.first + x.second == ptr) {
size_set_.erase(std::make_pair(x.second, x.first));
ptr_map_.erase(x.first);
ptr = x.first;
size += x.second;
}
}
size_set_.insert(std::make_pair(size, ptr));
ptr_map_[ptr] = size;
}

Ptr SNodeTreeBufferManager::allocate(JITModule *runtime_jit,
void *runtime,
std::size_t size,
std::size_t alignment,
Ptr SNodeTreeBufferManager::allocate(std::size_t size,
const int snode_tree_id,
uint64 *result_buffer) {
TI_TRACE("allocating memory for SNode Tree {}", snode_tree_id);
TI_ASSERT_INFO(snode_tree_id < kMaxNumSnodeTreesLlvm,
"LLVM backend supports up to {} snode trees",
kMaxNumSnodeTreesLlvm);
auto set_it = size_set_.lower_bound(std::make_pair(size, nullptr));
if (set_it == size_set_.end()) {
runtime_jit->call<void *, std::size_t, std::size_t>(
"runtime_memory_allocate_aligned", runtime, size, alignment,
result_buffer);
auto ptr = runtime_exec_->fetch_result<Ptr>(0, result_buffer);
roots_[snode_tree_id] = ptr;
sizes_[snode_tree_id] = size;
return ptr;
} else {
auto x = *set_it;
size_set_.erase(x);
ptr_map_.erase(x.second);
if (x.first - size > 0) {
size_set_.insert(std::make_pair(x.first - size, x.second + size));
ptr_map_[x.second + size] = x.first - size;
}
TI_ASSERT(x.second);
roots_[snode_tree_id] = x.second;
sizes_[snode_tree_id] = size;
return x.second;
}
auto devalloc = runtime_exec_->allocate_memory_ndarray(size, result_buffer);
snode_tree_id_to_device_alloc_[snode_tree_id] = devalloc;
return (Ptr)runtime_exec_->get_ndarray_alloc_info_ptr(devalloc);
}

void SNodeTreeBufferManager::destroy(SNodeTree *snode_tree) {
int snode_tree_id = snode_tree->id();
TI_TRACE("Destroying SNode tree {}.", snode_tree_id);
std::size_t size = sizes_[snode_tree_id];
if (size == 0) {
TI_DEBUG("SNode tree {} destroy failed.", snode_tree_id);
return;
}
Ptr ptr = roots_[snode_tree_id];
merge_and_insert(ptr, size);
TI_DEBUG("SNode tree {} destroyed.", snode_tree_id);
auto devalloc = snode_tree_id_to_device_alloc_[snode_tree->id()];
runtime_exec_->deallocate_memory_ndarray(devalloc);
snode_tree_id_to_device_alloc_.erase(snode_tree->id());
}

} // namespace taichi::lang
13 changes: 3 additions & 10 deletions taichi/runtime/llvm/snode_tree_buffer_manager.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once
#include "taichi/inc/constants.h"
#include "taichi/struct/snode_tree.h"
#include "taichi/rhi/public_device.h"
#define TI_RUNTIME_HOST

#include <set>
@@ -16,23 +17,15 @@ class SNodeTreeBufferManager {
public:
explicit SNodeTreeBufferManager(LlvmRuntimeExecutor *runtime_exec);

void merge_and_insert(Ptr ptr, std::size_t size);

Ptr allocate(JITModule *runtime_jit,
void *runtime,
std::size_t size,
std::size_t alignment,
Ptr allocate(std::size_t size,
const int snode_tree_id,
uint64 *result_buffer);

void destroy(SNodeTree *snode_tree);

private:
std::set<std::pair<std::size_t, Ptr>> size_set_;
std::map<Ptr, std::size_t> ptr_map_;
LlvmRuntimeExecutor *runtime_exec_;
Ptr roots_[kMaxNumSnodeTreesLlvm];
std::size_t sizes_[kMaxNumSnodeTreesLlvm];
std::map<int, DeviceAllocation> snode_tree_id_to_device_alloc_;
};

} // namespace taichi::lang