diff --git a/taichi/runtime/llvm/llvm_runtime_executor.cpp b/taichi/runtime/llvm/llvm_runtime_executor.cpp index 70e83c0a246e56..551a5bb2ea58f2 100644 --- a/taichi/runtime/llvm/llvm_runtime_executor.cpp +++ b/taichi/runtime/llvm/llvm_runtime_executor.cpp @@ -395,6 +395,8 @@ void LlvmRuntimeExecutor::initialize_llvm_runtime_snodes( const int root_id = field_cache_data.root_id; TI_TRACE("Allocating data structure of size {} bytes", root_size); + root_size = + std::max(root_size, (size_t)taichi_page_size); // minimal allocation size std::size_t rounded_size = taichi::iroundup(root_size, taichi_page_size); Ptr root_buffer = snode_tree_buffer_manager_->allocate( diff --git a/taichi/runtime/llvm/snode_tree_buffer_manager.cpp b/taichi/runtime/llvm/snode_tree_buffer_manager.cpp index 8498ddb736892b..63642ef33ef016 100644 --- a/taichi/runtime/llvm/snode_tree_buffer_manager.cpp +++ b/taichi/runtime/llvm/snode_tree_buffer_manager.cpp @@ -9,74 +9,21 @@ SNodeTreeBufferManager::SNodeTreeBufferManager( TI_TRACE("SNode tree buffer manager created."); } -void SNodeTreeBufferManager::merge_and_insert(Ptr ptr, std::size_t size) { - // merge with right block - if (ptr_map_[ptr + size]) { - std::size_t tmp = ptr_map_[ptr + size]; - size_set_.erase(std::make_pair(tmp, ptr + size)); - ptr_map_.erase(ptr + size); - size += tmp; - } - // merge with left block - auto map_it = ptr_map_.lower_bound(ptr); - if (map_it != ptr_map_.begin()) { - auto x = *--map_it; - if (x.first + x.second == ptr) { - size_set_.erase(std::make_pair(x.second, x.first)); - ptr_map_.erase(x.first); - ptr = x.first; - size += x.second; - } - } - size_set_.insert(std::make_pair(size, ptr)); - ptr_map_[ptr] = size; -} - Ptr SNodeTreeBufferManager::allocate(JITModule *runtime_jit, void *runtime, std::size_t size, std::size_t alignment, const int snode_tree_id, uint64 *result_buffer) { - TI_TRACE("allocating memory for SNode Tree {}", snode_tree_id); - TI_ASSERT_INFO(snode_tree_id < kMaxNumSnodeTreesLlvm, - "LLVM backend supports up to {} snode trees", - kMaxNumSnodeTreesLlvm); - auto set_it = size_set_.lower_bound(std::make_pair(size, nullptr)); - if (set_it == size_set_.end()) { - runtime_jit->call( - "runtime_memory_allocate_aligned", runtime, size, alignment, - result_buffer); - auto ptr = runtime_exec_->fetch_result(0, result_buffer); - roots_[snode_tree_id] = ptr; - sizes_[snode_tree_id] = size; - return ptr; - } else { - auto x = *set_it; - size_set_.erase(x); - ptr_map_.erase(x.second); - if (x.first - size > 0) { - size_set_.insert(std::make_pair(x.first - size, x.second + size)); - ptr_map_[x.second + size] = x.first - size; - } - TI_ASSERT(x.second); - roots_[snode_tree_id] = x.second; - sizes_[snode_tree_id] = size; - return x.second; - } + auto devalloc = runtime_exec_->allocate_memory_ndarray(size, result_buffer); + snode_tree_id_to_device_alloc_[snode_tree_id] = devalloc; + return (Ptr)runtime_exec_->get_ndarray_alloc_info_ptr(devalloc); } void SNodeTreeBufferManager::destroy(SNodeTree *snode_tree) { - int snode_tree_id = snode_tree->id(); - TI_TRACE("Destroying SNode tree {}.", snode_tree_id); - std::size_t size = sizes_[snode_tree_id]; - if (size == 0) { - TI_DEBUG("SNode tree {} destroy failed.", snode_tree_id); - return; - } - Ptr ptr = roots_[snode_tree_id]; - merge_and_insert(ptr, size); - TI_DEBUG("SNode tree {} destroyed.", snode_tree_id); + auto devalloc = snode_tree_id_to_device_alloc_[snode_tree->id()]; + runtime_exec_->deallocate_memory_ndarray(devalloc); + snode_tree_id_to_device_alloc_.erase(snode_tree->id()); } } // namespace taichi::lang diff --git a/taichi/runtime/llvm/snode_tree_buffer_manager.h b/taichi/runtime/llvm/snode_tree_buffer_manager.h index fd116ad84d0226..6c1bcb8d291773 100644 --- a/taichi/runtime/llvm/snode_tree_buffer_manager.h +++ b/taichi/runtime/llvm/snode_tree_buffer_manager.h @@ -1,6 +1,7 @@ #pragma once #include "taichi/inc/constants.h" #include "taichi/struct/snode_tree.h" +#include "taichi/rhi/public_device.h" #define TI_RUNTIME_HOST #include @@ -16,8 +17,6 @@ class SNodeTreeBufferManager { public: explicit SNodeTreeBufferManager(LlvmRuntimeExecutor *runtime_exec); - void merge_and_insert(Ptr ptr, std::size_t size); - Ptr allocate(JITModule *runtime_jit, void *runtime, std::size_t size, @@ -33,6 +32,8 @@ class SNodeTreeBufferManager { LlvmRuntimeExecutor *runtime_exec_; Ptr roots_[kMaxNumSnodeTreesLlvm]; std::size_t sizes_[kMaxNumSnodeTreesLlvm]; + + std::map snode_tree_id_to_device_alloc_; }; } // namespace taichi::lang