Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[backport] thrust/mr: fix the case of reuising a block for a smaller alloc. (#1232) #1317

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions thrust/testing/mr_pool.cu
Original file line number Diff line number Diff line change
Expand Up @@ -123,23 +123,26 @@ public:

virtual tracked_pointer<void> do_allocate(std::size_t n, std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT) override
{
ASSERT_EQUAL(static_cast<bool>(id_to_allocate), true);
ASSERT_EQUAL(id_to_allocate || id_to_allocate == -1u, true);

void * raw = upstream.do_allocate(n, alignment);
tracked_pointer<void> ret(raw);
ret.id = id_to_allocate;
ret.size = n;
ret.alignment = alignment;

id_to_allocate = 0;
if (id_to_allocate != -1u)
{
id_to_allocate = 0;
}

return ret;
}

virtual void do_deallocate(tracked_pointer<void> p, std::size_t n, std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT) override
{
ASSERT_EQUAL(p.size, n);
ASSERT_EQUAL(p.alignment, alignment);
ASSERT_GEQUAL(p.size, n);
ASSERT_GEQUAL(p.alignment, alignment);

if (id_to_deallocate != 0)
{
Expand Down Expand Up @@ -318,6 +321,36 @@ void TestPoolCachingOversized()
upstream.id_to_allocate = 7;
tracked_pointer<void> a9 = pool.do_allocate(2048, 32);
ASSERT_EQUAL(a9.id, 7u);

// make sure that reusing a larger oversized block for a smaller allocation works
// this is NVIDIA/cccl#585
upstream.id_to_allocate = 8;
tracked_pointer<void> a10 = pool.do_allocate(2048 + 16, THRUST_MR_DEFAULT_ALIGNMENT);
pool.do_deallocate(a10, 2048 + 16, THRUST_MR_DEFAULT_ALIGNMENT);
tracked_pointer<void> a11 = pool.do_allocate(2048, THRUST_MR_DEFAULT_ALIGNMENT);
ASSERT_EQUAL(a11.ptr, a10.ptr);
pool.do_deallocate(a11, 2048, THRUST_MR_DEFAULT_ALIGNMENT);

// original minimized reproducer from NVIDIA/cccl#585:
{
upstream.id_to_allocate = -1u;

auto ptr1 = pool.allocate(43920240);
auto ptr2 = pool.allocate(2465264);
pool.deallocate(ptr1, 43920240);
pool.deallocate(ptr2, 2465264);
auto ptr3 = pool.allocate(4930528);
pool.deallocate(ptr3, 4930528);
auto ptr4 = pool.allocate(14640080);
std::memset(thrust::raw_pointer_cast(ptr4), 0xff, 14640080);

auto crash = pool.allocate(4930528);

pool.deallocate(crash, 4930528);
pool.deallocate(ptr4, 14640080);

upstream.id_to_allocate = 0;
}
}

void TestUnsynchronizedPoolCachingOversized()
Expand Down
108 changes: 70 additions & 38 deletions thrust/thrust/mr/pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,17 @@ class unsynchronized_pool_resource final

private:
typedef typename Upstream::pointer void_ptr;
typedef typename thrust::detail::pointer_traits<void_ptr>::template rebind<char>::other char_ptr;
typedef thrust::detail::pointer_traits<void_ptr> void_ptr_traits;
typedef typename void_ptr_traits::template rebind<char>::other char_ptr;

struct block_descriptor;
struct chunk_descriptor;
struct oversized_block_descriptor;

typedef typename thrust::detail::pointer_traits<void_ptr>::template rebind<block_descriptor>::other block_descriptor_ptr;
typedef typename thrust::detail::pointer_traits<void_ptr>::template rebind<chunk_descriptor>::other chunk_descriptor_ptr;
typedef typename thrust::detail::pointer_traits<void_ptr>::template rebind<oversized_block_descriptor>::other oversized_block_descriptor_ptr;
typedef typename void_ptr_traits::template rebind<block_descriptor>::other block_descriptor_ptr;
typedef typename void_ptr_traits::template rebind<chunk_descriptor>::other chunk_descriptor_ptr;
typedef typename void_ptr_traits::template rebind<oversized_block_descriptor>::other oversized_block_descriptor_ptr;
typedef thrust::detail::pointer_traits<oversized_block_descriptor_ptr> oversized_block_ptr_traits;

struct block_descriptor
{
Expand Down Expand Up @@ -194,6 +196,7 @@ class unsynchronized_pool_resource final
oversized_block_descriptor_ptr prev;
oversized_block_descriptor_ptr next;
oversized_block_descriptor_ptr next_cached;
std::size_t current_size;
};

struct pool
Expand Down Expand Up @@ -244,17 +247,20 @@ class unsynchronized_pool_resource final
}

// deallocate cached oversized/overaligned memory
while (detail::pointer_traits<oversized_block_descriptor_ptr>::get(m_oversized))
while (oversized_block_ptr_traits::get(m_oversized))
{
oversized_block_descriptor_ptr alloc = m_oversized;
m_oversized = thrust::raw_reference_cast(*m_oversized).next;

oversized_block_descriptor desc =
thrust::raw_reference_cast(*alloc);

void_ptr p = static_cast<void_ptr>(
static_cast<char_ptr>(
static_cast<void_ptr>(alloc)
) - thrust::raw_reference_cast(*alloc).size
);
m_upstream->do_deallocate(p, thrust::raw_reference_cast(*alloc).size + sizeof(oversized_block_descriptor), thrust::raw_reference_cast(*alloc).alignment);
static_cast<char_ptr>(static_cast<void_ptr>(alloc)) -
desc.current_size);
m_upstream->do_deallocate(
p, desc.size + sizeof(oversized_block_descriptor),
desc.alignment);
}

m_cached_oversized = oversized_block_descriptor_ptr();
Expand All @@ -272,7 +278,7 @@ class unsynchronized_pool_resource final
{
oversized_block_descriptor_ptr ptr = m_cached_oversized;
oversized_block_descriptor_ptr * previous = &m_cached_oversized;
while (detail::pointer_traits<oversized_block_descriptor_ptr>::get(ptr))
while (oversized_block_ptr_traits::get(ptr))
{
oversized_block_descriptor desc = *ptr;
bool is_good = desc.size >= bytes && desc.alignment >= alignment;
Expand Down Expand Up @@ -305,23 +311,39 @@ class unsynchronized_pool_resource final
{
if (previous != &m_cached_oversized)
{
oversized_block_descriptor previous_desc = **previous;
previous_desc.next_cached = desc.next_cached;
**previous = previous_desc;
*previous = desc.next_cached;
}
else
{
m_cached_oversized = desc.next_cached;
}

desc.next_cached = oversized_block_descriptor_ptr();

auto ret =
static_cast<char_ptr>(static_cast<void_ptr>(ptr)) -
desc.size;

if (bytes != desc.size) {
desc.current_size = bytes;

ptr = static_cast<oversized_block_descriptor_ptr>(
static_cast<void_ptr>(ret + bytes));

if (oversized_block_ptr_traits::get(desc.prev)) {
thrust::raw_reference_cast(*desc.prev).next = ptr;
} else {
m_oversized = ptr;
}

if (oversized_block_ptr_traits::get(desc.next)) {
thrust::raw_reference_cast(*desc.next).prev = ptr;
}
}

*ptr = desc;

return static_cast<void_ptr>(
static_cast<char_ptr>(
static_cast<void_ptr>(ptr)
) - desc.size
);
return static_cast<void_ptr>(ret);
}

previous = &thrust::raw_reference_cast(*ptr).next_cached;
Expand All @@ -343,10 +365,11 @@ class unsynchronized_pool_resource final
desc.prev = oversized_block_descriptor_ptr();
desc.next = m_oversized;
desc.next_cached = oversized_block_descriptor_ptr();
desc.current_size = bytes;
*block = desc;
m_oversized = block;

if (detail::pointer_traits<oversized_block_descriptor_ptr>::get(desc.next))
if (oversized_block_ptr_traits::get(desc.next))
{
oversized_block_descriptor next = *desc.next;
next.prev = block;
Expand Down Expand Up @@ -439,7 +462,7 @@ class unsynchronized_pool_resource final
assert(detail::is_power_of_2(alignment));

// verify that the pointer is at least as aligned as claimed
assert(reinterpret_cast<detail::intmax_t>(detail::pointer_traits<void_ptr>::get(p)) % alignment == 0);
assert(reinterpret_cast<detail::intmax_t>(void_ptr_traits::get(p)) % alignment == 0);

// the deallocated block is oversized and/or overaligned
if (n > m_options.largest_block_size || alignment > m_options.alignment)
Expand All @@ -451,35 +474,44 @@ class unsynchronized_pool_resource final
);

oversized_block_descriptor desc = *block;
assert(desc.current_size == n);
assert(desc.alignment == alignment);

if (m_options.cache_oversized)
{
desc.next_cached = m_cached_oversized;
*block = desc;

if (desc.size != n) {
desc.current_size = desc.size;
block = static_cast<oversized_block_descriptor_ptr>(
static_cast<void_ptr>(static_cast<char_ptr>(p) +
desc.size));
if (oversized_block_ptr_traits::get(desc.prev)) {
thrust::raw_reference_cast(*desc.prev).next = block;
} else {
m_oversized = block;
}

if (oversized_block_ptr_traits::get(desc.next)) {
thrust::raw_reference_cast(*desc.next).prev = block;
}
}

m_cached_oversized = block;
*block = desc;

return;
}

if (!detail::pointer_traits<oversized_block_descriptor_ptr>::get(desc.prev))
{
assert(m_oversized == block);
if (oversized_block_ptr_traits::get(
desc.prev)) {
thrust::raw_reference_cast(*desc.prev).next = desc.next;
} else {
m_oversized = desc.next;
}
else
{
oversized_block_descriptor prev = *desc.prev;
assert(prev.next == block);
prev.next = desc.next;
*desc.prev = prev;
}

if (detail::pointer_traits<oversized_block_descriptor_ptr>::get(desc.next))
{
oversized_block_descriptor next = *desc.next;
assert(next.prev == block);
next.prev = desc.prev;
*desc.next = next;
if (oversized_block_ptr_traits::get(desc.next)) {
thrust::raw_reference_cast(*desc.next).prev = desc.prev;
}

m_upstream->do_deallocate(p, desc.size + sizeof(oversized_block_descriptor), desc.alignment);
Expand Down
Loading