diff --git a/src/compute_cluster_size_ext.cpp b/src/compute_cluster_size_ext.cpp index 93782a45196..56bace7ccf3 100644 --- a/src/compute_cluster_size_ext.cpp +++ b/src/compute_cluster_size_ext.cpp @@ -57,15 +57,15 @@ ComputeClusterSizeExt::ComputeClusterSizeExt(LAMMPS* lmp, int narg, char** arg) size_cutoff = utils::inumeric(FLERR, arg[4], true, lmp); if (size_cutoff < 1) { error->all(FLERR, "size_cutoff for compute cluster/size must be greater than 0"); } - keeper1 = new MemoryKeeper>(memory); + keeper1 = new MemoryKeeper(memory); cluster_map_allocator = new MapAlloc_t(keeper1); cluster_map = new Map_t(*cluster_map_allocator); - keeper2 = new MemoryKeeper>>(memory); + keeper2 = new MemoryKeeper(memory); alloc_map_vec1 = new MapAlloc_t>(keeper2); cIDs_by_size = new Map_t>(*alloc_map_vec1); - keeper3 = new MemoryKeeper>>(memory); + keeper3 = new MemoryKeeper(memory); alloc_map_vec2 = new MapAlloc_t>(keeper3); cIDs_by_size_all = new Map_t>(*alloc_map_vec2); @@ -112,9 +112,9 @@ void ComputeClusterSizeExt::init() vector = dist.data(); nloc = static_cast(atom->nlocal * LMP_NUCC_ALLOC_COEFF); - keeper1->pool_size(nloc); - keeper2->pool_size(nloc); - keeper3->pool_size(nloc); + keeper1->pool_size>(nloc); + keeper2->pool_size>>(nloc); + keeper3->pool_size>>(nloc); cluster_map->reserve(nloc); cIDs_by_size->reserve(size_cutoff); diff --git a/src/compute_cluster_size_ext.h b/src/compute_cluster_size_ext.h index 8fbd5634d01..c1b5cf14d1d 100644 --- a/src/compute_cluster_size_ext.h +++ b/src/compute_cluster_size_ext.h @@ -101,22 +101,22 @@ class ComputeClusterSizeExt : public Compute { inline constexpr const NUCC::Map_t *get_cluster_map() const noexcept(true) { return cluster_map; } inline constexpr const NUCC::Map_t> *get_cIDs_by_size() const noexcept(true) { return cIDs_by_size; } inline constexpr const NUCC::Map_t> *get_cIDs_by_size_all() const noexcept(true) { return cIDs_by_size_all; } - inline constexpr const NUCC::cspan &get_clusters() const noexcept(true) { return clusters; } + inline constexpr const NUCC::cspan get_clusters() const noexcept(true) { return clusters; } private: int size_cutoff; // number of elements reserved in dist - NUCC::MemoryKeeper> *keeper1; + NUCC::MemoryKeeper *keeper1; NUCC::MapAlloc_t *cluster_map_allocator; NUCC::Map_t *cluster_map; // std::unordered_map cluster_map; // clid -> idx - NUCC::MemoryKeeper>> *keeper2; + NUCC::MemoryKeeper *keeper2; NUCC::MapAlloc_t> *alloc_map_vec1; NUCC::Map_t> *cIDs_by_size; // std::unordered_map> cIDs_by_size; // size -> vector(idx) - NUCC::MemoryKeeper>> *keeper3; + NUCC::MemoryKeeper *keeper3; NUCC::MapAlloc_t> *alloc_map_vec2; NUCC::Map_t> *cIDs_by_size_all; // std::unordered_map> cIDs_by_size_all; diff --git a/src/nucc_allocator.hpp b/src/nucc_allocator.hpp index 8a34b79ed88..852fa6aeccd 100644 --- a/src/nucc_allocator.hpp +++ b/src/nucc_allocator.hpp @@ -11,7 +11,6 @@ namespace NUCC { -template class MemoryKeeper { public: MemoryKeeper() = delete; @@ -20,15 +19,16 @@ class MemoryKeeper { MemoryKeeper& operator=(const MemoryKeeper&) = delete; MemoryKeeper& operator=(MemoryKeeper&&) = delete; - constexpr MemoryKeeper(Memory* memory) noexcept : memory_(memory) {} + constexpr MemoryKeeper(LAMMPS_NS::Memory* memory) noexcept : memory_(memory) {} ~MemoryKeeper() noexcept(noexcept(clear())) { clear(); } + template void store(T*& ptr, const size_t size) noexcept(noexcept(infos.emplace_back(ptr, size))) { infos.emplace_back(ptr, size); } - void clear() noexcept(noexcept(std::declval().destroy(std::declval()))) + void clear() noexcept(noexcept(std::declval().destroy(std::declval()))) { for (auto& pool : infos) { memory_->destroy(pool.ptr); } } @@ -39,31 +39,40 @@ class MemoryKeeper { return _pool_size; } + template + inline constexpr void pool_size(std::size_t n) noexcept { _pool_size = n * sizeof(T); } + inline constexpr void pool_size(std::size_t n) noexcept { _pool_size = n; } + template T* allocate(const std::size_t n) { T* ptr = nullptr; - if (n > _pool_size) { + std::size_t pool_size_T = _pool_size / sizeof(T) + 1; + if (n > pool_size_T) { // If requested size is larger than pool, allocate separately memory_->create(ptr, n, "CustomAllocator_Large"); infos.emplace_back(ptr, n); return ptr; } - if (current == nullptr || left < n) { + std::size_t nbytes = n * sizeof(T); + T* _current = reinterpret_cast(current); + if ((current == nullptr) || (left < nbytes)) { // Pool is full or not initialized, request a new pool - memory_->create(current, _pool_size, "CustomAllocator_Pool"); - infos.emplace_back(current, _pool_size); + memory_->create(_current, pool_size_T, "CustomAllocator_Pool"); + infos.emplace_back(_current, pool_size_T); left = _pool_size; } - ptr = current; - left -= n; - current += n; + ptr = _current; + left -= nbytes; + _current += n; + current = reinterpret_cast(_current); return ptr; } private: struct PoolInfo { + template constexpr PoolInfo(T*& ptr, const size_t size) noexcept : ptr(reinterpret_cast(ptr)), size(size * sizeof(T)) { } @@ -74,9 +83,9 @@ class MemoryKeeper { size_t size = 0; }; - T* current = nullptr; + void* current = nullptr; std::size_t left = 0; - Memory* const memory_ = nullptr; + LAMMPS_NS::Memory* const memory_ = nullptr; std::size_t _pool_size = 0; std::vector infos; }; @@ -95,16 +104,16 @@ class CustomAllocator { CustomAllocator() = delete; - constexpr CustomAllocator(MemoryKeeper* const keeper) noexcept : keeper_(keeper) {} + constexpr CustomAllocator(MemoryKeeper* const keeper) noexcept : keeper_(keeper) {} template constexpr CustomAllocator(const CustomAllocator& other) noexcept : keeper_(other.keeper_) { } - inline T* allocate(const std::size_t n) const { return keeper_->allocate(n); } + inline T* allocate(std::size_t n) const { return keeper_->allocate(n); } - inline constexpr void deallocate(T /**p*/, const std::size_t /*n*/) const noexcept + inline constexpr void deallocate(T* /*p*/, const std::size_t /*n*/) const noexcept { // Deallocation can be handled when the allocator is destroyed // For pool allocator, individual deallocations are often no-ops @@ -124,7 +133,7 @@ class CustomAllocator { } private: - MemoryKeeper* const keeper_; + MemoryKeeper* const keeper_; }; } // namespace NUCC