[lang] Refactor CudaCachingAllocator into a more generic caching allocator (taichi-dev#7531)

Issue: taichi-dev#7300.
 
### Brief Summary
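Renames `CudaCachingAllocator` to `CachingAllocator` and moves it from `taichi/rhi/cuda/cuda_caching_allocator.*` to `taichi/rhi/llvm/allocator.*`, so the allocator no longer depends on the CUDA backend and can be shared by any `LlvmDevice`-based backend. The free-block store changes from a `std::multimap<size_t, uint64_t *>` to a size-ordered `std::set` plus an address-keyed `std::map`, and a new `merge_and_insert` helper lets `release()` coalesce a freed block with adjacent cached blocks before reinserting it.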

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and quadpixels committed May 13, 2023
1 parent cea3114 commit 143db3c
Showing 8 changed files with 75 additions and 49 deletions.
1 change: 0 additions & 1 deletion taichi/rhi/cuda/CMakeLists.txt
@@ -5,7 +5,6 @@ add_library(${CUDA_RHI})
target_sources(${CUDA_RHI}
PRIVATE
cuda_device.cpp
cuda_caching_allocator.cpp
cuda_context.cpp
cuda_driver.cpp
cuda_profiler.cpp
38 changes: 0 additions & 38 deletions taichi/rhi/cuda/cuda_caching_allocator.cpp

This file was deleted.

4 changes: 2 additions & 2 deletions taichi/rhi/cuda/cuda_device.cpp
@@ -44,7 +44,7 @@ DeviceAllocation CudaDevice::allocate_memory_runtime(
info.size = taichi::iroundup(params.size, taichi_page_size);
if (params.use_cached) {
if (caching_allocator_ == nullptr) {
caching_allocator_ = std::make_unique<CudaCachingAllocator>(this);
caching_allocator_ = std::make_unique<CachingAllocator>(this);
}
info.ptr = caching_allocator_->allocate(params);
CUDADriver::get_instance().memset((void *)info.ptr, 0, info.size);
@@ -72,7 +72,7 @@ void CudaDevice::dealloc_memory(DeviceAllocation handle) {
TI_ASSERT(!info.is_imported);
if (info.use_cached) {
if (caching_allocator_ == nullptr) {
TI_ERROR("the CudaCachingAllocator is not initialized");
TI_ERROR("the CachingAllocator is not initialized");
}
caching_allocator_->release(info.size, (uint64_t *)info.ptr);
} else if (!info.use_preallocated) {
4 changes: 2 additions & 2 deletions taichi/rhi/cuda/cuda_device.h
@@ -4,7 +4,7 @@

#include "taichi/common/core.h"
#include "taichi/rhi/cuda/cuda_driver.h"
#include "taichi/rhi/cuda/cuda_caching_allocator.h"
#include "taichi/rhi/llvm/allocator.h"
#include "taichi/rhi/cuda/cuda_context.h"
#include "taichi/rhi/llvm/llvm_device.h"

@@ -136,7 +136,7 @@ class CudaDevice : public LlvmDevice {
TI_ERROR("invalid DeviceAllocation");
}
}
std::unique_ptr<CudaCachingAllocator> caching_allocator_{nullptr};
std::unique_ptr<CachingAllocator> caching_allocator_{nullptr};
};

} // namespace cuda
1 change: 1 addition & 0 deletions taichi/rhi/llvm/CMakeLists.txt
@@ -5,6 +5,7 @@ add_library(${LLVM_RHI})
target_sources(${LLVM_RHI}
PRIVATE
llvm_device.cpp
allocator.cpp
)

target_include_directories(${LLVM_RHI}
62 changes: 62 additions & 0 deletions taichi/rhi/llvm/allocator.cpp
@@ -0,0 +1,62 @@
#include "taichi/rhi/llvm/allocator.h"
#include "taichi/runtime/llvm/snode_tree_buffer_manager.h"

namespace taichi::lang {

CachingAllocator::CachingAllocator(LlvmDevice *device) : device_(device) {
}

void CachingAllocator::merge_and_insert(uint8_t *ptr, std::size_t size) {
// merge with right block
if (ptr_map_[ptr + size]) {
std::size_t tmp = ptr_map_[ptr + size];
mem_blocks_.erase(std::make_pair(tmp, ptr + size));
ptr_map_.erase(ptr + size);
size += tmp;
}
// merge with left block
auto map_it = ptr_map_.lower_bound(ptr);
if (map_it != ptr_map_.begin()) {
auto x = *--map_it;
if (x.first + x.second == ptr) {
mem_blocks_.erase(std::make_pair(x.second, x.first));
ptr_map_.erase(x.first);
ptr = x.first;
size += x.second;
}
}
mem_blocks_.insert(std::make_pair(size, ptr));
ptr_map_[ptr] = size;
}

uint64_t *CachingAllocator::allocate(
const LlvmDevice::LlvmRuntimeAllocParams &params) {
uint64_t *ret{nullptr};
auto size_aligned = taichi::iroundup(params.size, taichi_page_size);
auto it_blk = mem_blocks_.lower_bound(std::make_pair(size_aligned, nullptr));

if (it_blk != mem_blocks_.end()) {
size_t remaining_sz = it_blk->first - size_aligned;
if (remaining_sz > 0) {
TI_ASSERT(remaining_sz % taichi_page_size == 0);
auto remaining_head =
reinterpret_cast<uint8_t *>(it_blk->second) + size_aligned;
mem_blocks_.insert(std::make_pair(remaining_sz, remaining_head));
ptr_map_.insert(std::make_pair(remaining_head, remaining_sz));
}
ret = reinterpret_cast<uint64_t *>(it_blk->second);
mem_blocks_.erase(it_blk);
ptr_map_.erase(it_blk->second);

} else {
ret = reinterpret_cast<uint64_t *>(
device_->allocate_llvm_runtime_memory_jit(params));
}
return ret;
}

void CachingAllocator::release(size_t sz, uint64_t *ptr) {
merge_and_insert(reinterpret_cast<uint8_t *>(ptr), sz);
}

} // namespace taichi::lang
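For readers unfamiliar with this free-list layout, the standalone sketch below shows the same best-fit-with-coalescing scheme that the new `allocate`/`release` pair uses: a size-ordered `std::set` gives best-fit lookup via `lower_bound`, and an address-keyed `std::map` lets a freed block be merged with its left and right neighbors. `FreeList`, `acquire`, and the `main` driver are illustrative names for this sketch, not Taichi code.

```cpp
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <map>
#include <set>
#include <utility>

class FreeList {
 public:
  // Return a cached block of at least `size` bytes, or nullptr if none fits.
  uint8_t *acquire(std::size_t size) {
    auto it = blocks_.lower_bound({size, nullptr});  // smallest block >= size
    if (it == blocks_.end()) return nullptr;
    uint8_t *ptr = it->second;
    std::size_t remaining = it->first - size;
    blocks_.erase(it);
    by_ptr_.erase(ptr);
    if (remaining > 0) {
      insert(ptr + size, remaining);  // keep the unused tail cached
    }
    return ptr;
  }

  // Return a block to the cache, merging with adjacent free blocks.
  void release(uint8_t *ptr, std::size_t size) {
    // Merge with the block that starts right after [ptr, ptr + size).
    auto right = by_ptr_.find(ptr + size);
    if (right != by_ptr_.end()) {
      size += right->second;
      blocks_.erase({right->second, right->first});
      by_ptr_.erase(right);
    }
    // Merge with the block that ends exactly at ptr.
    auto left = by_ptr_.lower_bound(ptr);
    if (left != by_ptr_.begin()) {
      --left;
      if (left->first + left->second == ptr) {
        ptr = left->first;
        size += left->second;
        blocks_.erase({left->second, left->first});
        by_ptr_.erase(left);
      }
    }
    insert(ptr, size);
  }

 private:
  void insert(uint8_t *ptr, std::size_t size) {
    blocks_.insert({size, ptr});
    by_ptr_[ptr] = size;
  }

  std::set<std::pair<std::size_t, uint8_t *>> blocks_;  // (size, ptr), best-fit order
  std::map<uint8_t *, std::size_t> by_ptr_;             // ptr -> size, for merging
};

int main() {
  FreeList cache;
  uint8_t *arena = static_cast<uint8_t *>(std::malloc(4096));
  cache.release(arena, 4096);        // seed the cache with one 4 KiB block
  uint8_t *a = cache.acquire(1024);  // best fit: head of the 4 KiB block
  uint8_t *b = cache.acquire(1024);  // next 1 KiB, carved from the tail
  cache.release(a, 1024);
  cache.release(b, 1024);            // coalesces back into a single 4 KiB block
  std::cout << (cache.acquire(4096) == arena) << "\n";  // prints 1
  std::free(arena);
  return 0;
}
```

The second container keyed by address is what makes the neighbor lookups in `release()` O(log n); with only the old size-keyed `std::multimap`, finding an adjacent block would require scanning every cached entry.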
12 changes: 7 additions & 5 deletions taichi/rhi/{cuda/cuda_caching_allocator.h → llvm/allocator.h}
@@ -6,21 +6,23 @@
#include "taichi/inc/constants.h"
#include <stdint.h>
#include <map>
#include <set>

namespace taichi::lang {
namespace cuda {

class CudaCachingAllocator {
class CachingAllocator {
public:
explicit CudaCachingAllocator(LlvmDevice *device);
explicit CachingAllocator(LlvmDevice *device);

uint64_t *allocate(const LlvmDevice::LlvmRuntimeAllocParams &params);
void release(size_t sz, uint64_t *ptr);

private:
std::multimap<size_t, uint64_t *> mem_blocks_;
void merge_and_insert(uint8_t *ptr, std::size_t size);

std::set<std::pair<std::size_t, uint8_t *>> mem_blocks_;
std::map<uint8_t *, std::size_t> ptr_map_;
LlvmDevice *device_{nullptr};
};

} // namespace cuda
} // namespace taichi::lang
2 changes: 1 addition & 1 deletion tests/python/test_ndarray.py
@@ -273,7 +273,7 @@ def test_ndarray_deepcopy():


@test_utils.test(arch=[ti.cuda], ndarray_use_cached_allocator=True)
def test_ndarray_cuda_caching_allocator():
def test_ndarray_caching_allocator():
n = 8
a = ti.ndarray(ti.i32, shape=(n))
a.fill(2)
