From 7627bb99fa48300b942748b7afcc0b9acb7bb198 Mon Sep 17 00:00:00 2001
From: Zhanlue Yang <jim19930609@gmail.com>
Date: Wed, 15 Mar 2023 16:14:48 +0800
Subject: [PATCH] [lang] Refactor CudaCachingAllocator into a more generic
 caching allocator (#7531)

Issue: #7300

### Brief Summary

This PR turns the CUDA-only caching allocator into a generic one that any
`LlvmDevice`-based backend can reuse:

- `taichi/rhi/cuda/cuda_caching_allocator.{h,cpp}` moves to
  `taichi/rhi/llvm/allocator.{h,cpp}`, and `cuda::CudaCachingAllocator` is
  renamed to `CachingAllocator`.
- The free list is upgraded from a `std::multimap<size_t, uint64_t *>` to a
  size-ordered `std::set` of `(size, address)` pairs plus an address-keyed
  `std::map` (`ptr_map_`), so that `release()` can coalesce a freed block with
  its adjacent free neighbors (`merge_and_insert()`) before caching it,
  reducing fragmentation.
- `allocate()` remains best-fit: it takes the smallest cached block that fits
  the page-aligned request, splits off any remainder, and falls back to
  `allocate_llvm_runtime_memory_jit()` on a cache miss.
- The CUDA ndarray test is renamed accordingly
  (`test_ndarray_caching_allocator`).
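
For context, here is a minimal, self-contained sketch of the coalescing
bookkeeping (the `MemBlocks`/`PtrMap` aliases, the free-standing
`merge_and_insert()`, and the `arena` driver are illustrative stand-ins for
the `CachingAllocator` members, not code from this patch):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <set>

// Free blocks ordered by (size, address) for best-fit lookup, plus an
// address-keyed index so a release can find adjacent blocks to merge.
using MemBlocks = std::set<std::pair<std::size_t, std::uint8_t *>>;
using PtrMap = std::map<std::uint8_t *, std::size_t>;

// Coalesce the freed range [ptr, ptr + size) with any free neighbors,
// then record the merged block in both structures.
void merge_and_insert(MemBlocks &blocks, PtrMap &ptrs, std::uint8_t *ptr,
                      std::size_t size) {
  auto right = ptrs.find(ptr + size);  // free block starting where ours ends
  if (right != ptrs.end()) {
    blocks.erase({right->second, right->first});
    size += right->second;
    ptrs.erase(right);
  }
  auto left = ptrs.lower_bound(ptr);
  if (left != ptrs.begin()) {
    --left;  // candidate free block just below ours
    if (left->first + left->second == ptr) {
      blocks.erase({left->second, left->first});
      ptr = left->first;
      size += left->second;
      ptrs.erase(left);
    }
  }
  blocks.insert({size, ptr});
  ptrs[ptr] = size;
}

int main() {
  MemBlocks blocks;
  PtrMap ptrs;
  static std::uint8_t arena[3072];
  // Free two non-adjacent 1 KB pages, then the page between them.
  merge_and_insert(blocks, ptrs, arena, 1024);
  merge_and_insert(blocks, ptrs, arena + 2048, 1024);
  merge_and_insert(blocks, ptrs, arena + 1024, 1024);
  assert(blocks.size() == 1);             // everything coalesced
  assert(blocks.begin()->first == 3072);  // into one 3 KB block
  return 0;
}
```

Keeping a size-ordered set alongside an address-ordered map of the same
blocks is what lets `allocate()` stay best-fit while `release()` merges
adjacent neighbors in logarithmic time.
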
---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 taichi/rhi/cuda/CMakeLists.txt                |  1 -
 taichi/rhi/cuda/cuda_caching_allocator.cpp    | 38 ------------
 taichi/rhi/cuda/cuda_device.cpp               |  4 +-
 taichi/rhi/cuda/cuda_device.h                 |  4 +-
 taichi/rhi/llvm/CMakeLists.txt                |  1 +
 taichi/rhi/llvm/allocator.cpp                 | 65 +++++++++++++++++++
 .../allocator.h}                              | 14 +++--
 tests/python/test_ndarray.py                  |  2 +-
 8 files changed, 80 insertions(+), 49 deletions(-)
 delete mode 100644 taichi/rhi/cuda/cuda_caching_allocator.cpp
 create mode 100644 taichi/rhi/llvm/allocator.cpp
 rename taichi/rhi/{cuda/cuda_caching_allocator.h => llvm/allocator.h} (62%)

diff --git a/taichi/rhi/cuda/CMakeLists.txt b/taichi/rhi/cuda/CMakeLists.txt
index 1287602a2172b..0e1fdca52ff16 100644
--- a/taichi/rhi/cuda/CMakeLists.txt
+++ b/taichi/rhi/cuda/CMakeLists.txt
@@ -5,7 +5,6 @@ add_library(${CUDA_RHI})
 target_sources(${CUDA_RHI}
   PRIVATE
     cuda_device.cpp
-    cuda_caching_allocator.cpp
     cuda_context.cpp
     cuda_driver.cpp
     cuda_profiler.cpp
diff --git a/taichi/rhi/cuda/cuda_caching_allocator.cpp b/taichi/rhi/cuda/cuda_caching_allocator.cpp
deleted file mode 100644
index 7f1fab4384edb..0000000000000
--- a/taichi/rhi/cuda/cuda_caching_allocator.cpp
+++ /dev/null
@@ -1,38 +0,0 @@
-#include "taichi/rhi/cuda/cuda_caching_allocator.h"
-
-namespace taichi::lang {
-namespace cuda {
-
-CudaCachingAllocator::CudaCachingAllocator(LlvmDevice *device)
-    : device_(device) {
-}
-
-uint64_t *CudaCachingAllocator::allocate(
-    const LlvmDevice::LlvmRuntimeAllocParams &params) {
-  uint64_t *ret{nullptr};
-  auto size_aligned = taichi::iroundup(params.size, taichi_page_size);
-  auto it_blk = mem_blocks_.lower_bound(size_aligned);
-
-  if (it_blk != mem_blocks_.end()) {
-    size_t remaining_sz = it_blk->first - size_aligned;
-    if (remaining_sz > 0) {
-      TI_ASSERT(remaining_sz % taichi_page_size == 0);
-      auto remaining_head =
-          reinterpret_cast<uint8_t *>(it_blk->second) + size_aligned;
-      mem_blocks_.insert(
-          {remaining_sz, reinterpret_cast<uint64_t *>(remaining_head)});
-    }
-    ret = it_blk->second;
-    mem_blocks_.erase(it_blk);
-  } else {
-    ret = device_->allocate_llvm_runtime_memory_jit(params);
-  }
-  return ret;
-}
-
-void CudaCachingAllocator::release(size_t sz, uint64_t *ptr) {
-  mem_blocks_.insert({sz, ptr});
-}
-
-}  // namespace cuda
-}  // namespace taichi::lang
diff --git a/taichi/rhi/cuda/cuda_device.cpp b/taichi/rhi/cuda/cuda_device.cpp
index 07281ee442db5..0a59f5bbf84ab 100644
--- a/taichi/rhi/cuda/cuda_device.cpp
+++ b/taichi/rhi/cuda/cuda_device.cpp
@@ -44,7 +44,7 @@ DeviceAllocation CudaDevice::allocate_memory_runtime(
   info.size = taichi::iroundup(params.size, taichi_page_size);
   if (params.use_cached) {
     if (caching_allocator_ == nullptr) {
-      caching_allocator_ = std::make_unique<CudaCachingAllocator>(this);
+      caching_allocator_ = std::make_unique<CachingAllocator>(this);
     }
     info.ptr = caching_allocator_->allocate(params);
     CUDADriver::get_instance().memset((void *)info.ptr, 0, info.size);
@@ -72,7 +72,7 @@ void CudaDevice::dealloc_memory(DeviceAllocation handle) {
   TI_ASSERT(!info.is_imported);
   if (info.use_cached) {
     if (caching_allocator_ == nullptr) {
-      TI_ERROR("the CudaCachingAllocator is not initialized");
+      TI_ERROR("the CachingAllocator is not initialized");
     }
     caching_allocator_->release(info.size, (uint64_t *)info.ptr);
   } else if (!info.use_preallocated) {
diff --git a/taichi/rhi/cuda/cuda_device.h b/taichi/rhi/cuda/cuda_device.h
index 59e039ac84a50..c1e0185af879b 100644
--- a/taichi/rhi/cuda/cuda_device.h
+++ b/taichi/rhi/cuda/cuda_device.h
@@ -4,7 +4,7 @@
 
 #include "taichi/common/core.h"
 #include "taichi/rhi/cuda/cuda_driver.h"
-#include "taichi/rhi/cuda/cuda_caching_allocator.h"
+#include "taichi/rhi/llvm/allocator.h"
 #include "taichi/rhi/cuda/cuda_context.h"
 #include "taichi/rhi/llvm/llvm_device.h"
 
@@ -136,7 +136,7 @@ class CudaDevice : public LlvmDevice {
       TI_ERROR("invalid DeviceAllocation");
     }
   }
-  std::unique_ptr<CudaCachingAllocator> caching_allocator_{nullptr};
+  std::unique_ptr<CachingAllocator> caching_allocator_{nullptr};
 };
 
 }  // namespace cuda
diff --git a/taichi/rhi/llvm/CMakeLists.txt b/taichi/rhi/llvm/CMakeLists.txt
index 64bf1be4dbb1f..588e9707a8494 100644
--- a/taichi/rhi/llvm/CMakeLists.txt
+++ b/taichi/rhi/llvm/CMakeLists.txt
@@ -5,6 +5,7 @@ add_library(${LLVM_RHI})
 target_sources(${LLVM_RHI}
   PRIVATE
     llvm_device.cpp
+    allocator.cpp
   )
 
 target_include_directories(${LLVM_RHI}
diff --git a/taichi/rhi/llvm/allocator.cpp b/taichi/rhi/llvm/allocator.cpp
new file mode 100644
index 0000000000000..89091871afbe8
--- /dev/null
+++ b/taichi/rhi/llvm/allocator.cpp
@@ -0,0 +1,65 @@
+#include "taichi/rhi/llvm/allocator.h"
+#include "taichi/runtime/llvm/snode_tree_buffer_manager.h"
+
+namespace taichi::lang {
+
+CachingAllocator::CachingAllocator(LlvmDevice *device) : device_(device) {
+}
+
+void CachingAllocator::merge_and_insert(uint8_t *ptr, std::size_t size) {
+  // merge with right block; find() avoids inserting a phantom zero-size entry
+  auto it_right = ptr_map_.find(ptr + size);
+  if (it_right != ptr_map_.end()) {
+    mem_blocks_.erase(std::make_pair(it_right->second, ptr + size));
+    size += it_right->second;
+    ptr_map_.erase(it_right);
+  }
+  // merge with left block
+  auto map_it = ptr_map_.lower_bound(ptr);
+  if (map_it != ptr_map_.begin()) {
+    auto x = *--map_it;
+    if (x.first + x.second == ptr) {
+      mem_blocks_.erase(std::make_pair(x.second, x.first));
+      ptr_map_.erase(x.first);
+      ptr = x.first;
+      size += x.second;
+    }
+  }
+  mem_blocks_.insert(std::make_pair(size, ptr));
+  ptr_map_[ptr] = size;
+}
+
+uint64_t *CachingAllocator::allocate(
+    const LlvmDevice::LlvmRuntimeAllocParams &params) {
+  uint64_t *ret{nullptr};
+  auto size_aligned = taichi::iroundup(params.size, taichi_page_size);
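+  // Best-fit: pick the smallest cached block that fits the aligned size.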
+  auto it_blk = mem_blocks_.lower_bound(std::make_pair(size_aligned, nullptr));
+
+  if (it_blk != mem_blocks_.end()) {
+    size_t remaining_sz = it_blk->first - size_aligned;
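+    // Split off any unused tail of the block and keep it cached.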
+    if (remaining_sz > 0) {
+      TI_ASSERT(remaining_sz % taichi_page_size == 0);
+      auto remaining_head =
+          reinterpret_cast<uint8_t *>(it_blk->second) + size_aligned;
+      mem_blocks_.insert(std::make_pair(remaining_sz, remaining_head));
+      ptr_map_.insert(std::make_pair(remaining_head, remaining_sz));
+    }
+    ret = reinterpret_cast<uint64_t *>(it_blk->second);
+    ptr_map_.erase(it_blk->second);
+    mem_blocks_.erase(it_blk);
+
+  } else {
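+    // Cache miss: fall back to a fresh allocation from the device runtime.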
+    ret = reinterpret_cast<uint64_t *>(
+        device_->allocate_llvm_runtime_memory_jit(params));
+  }
+  return ret;
+}
+
+void CachingAllocator::release(size_t sz, uint64_t *ptr) {
+  merge_and_insert(reinterpret_cast<uint8_t *>(ptr), sz);
+}
+
+}  // namespace taichi::lang
diff --git a/taichi/rhi/cuda/cuda_caching_allocator.h b/taichi/rhi/llvm/allocator.h
similarity index 62%
rename from taichi/rhi/cuda/cuda_caching_allocator.h
rename to taichi/rhi/llvm/allocator.h
index e86aac68fd7f4..9dd1263913c91 100644
--- a/taichi/rhi/cuda/cuda_caching_allocator.h
+++ b/taichi/rhi/llvm/allocator.h
@@ -6,21 +6,25 @@
 #include "taichi/inc/constants.h"
 #include <stdint.h>
 #include <map>
+#include <set>
 
 namespace taichi::lang {
-namespace cuda {
 
-class CudaCachingAllocator {
+class CachingAllocator {
  public:
-  explicit CudaCachingAllocator(LlvmDevice *device);
+  explicit CachingAllocator(LlvmDevice *device);
 
   uint64_t *allocate(const LlvmDevice::LlvmRuntimeAllocParams &params);
   void release(size_t sz, uint64_t *ptr);
 
  private:
-  std::multimap<size_t, uint64_t *> mem_blocks_;
+  void merge_and_insert(uint8_t *ptr, std::size_t size);
+
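+  // Free blocks ordered by (size, address) for best-fit allocation;
+  // ptr_map_ indexes the same blocks by address so release() can coalesce.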
+  std::set<std::pair<std::size_t, uint8_t *>> mem_blocks_;
+  std::map<uint8_t *, std::size_t> ptr_map_;
   LlvmDevice *device_{nullptr};
 };
 
-}  // namespace cuda
 }  // namespace taichi::lang
diff --git a/tests/python/test_ndarray.py b/tests/python/test_ndarray.py
index e6ba13be0eed1..a0e556f4c52eb 100644
--- a/tests/python/test_ndarray.py
+++ b/tests/python/test_ndarray.py
@@ -273,7 +273,7 @@ def test_ndarray_deepcopy():
 
 
 @test_utils.test(arch=[ti.cuda], ndarray_use_cached_allocator=True)
-def test_ndarray_cuda_caching_allocator():
+def test_ndarray_caching_allocator():
     n = 8
     a = ti.ndarray(ti.i32, shape=(n))
     a.fill(2)