diff --git a/python/rmm/_lib/memory_resource.pyx b/python/rmm/_lib/memory_resource.pyx index a20d481e0..afeeb3557 100644 --- a/python/rmm/_lib/memory_resource.pyx +++ b/python/rmm/_lib/memory_resource.pyx @@ -16,6 +16,7 @@ import os import warnings from collections import defaultdict +cimport cython from cython.operator cimport dereference as deref from libc.stdint cimport int8_t, int64_t, uintptr_t from libcpp cimport bool @@ -247,6 +248,8 @@ cdef class DeviceMemoryResource: self.c_obj.get().deallocate((ptr), nbytes) +# See the note about `no_gc_clear` in `device_buffer.pyx`. +@cython.no_gc_clear cdef class UpstreamResourceAdaptor(DeviceMemoryResource): def __cinit__(self, DeviceMemoryResource upstream_mr, *args, **kwargs): diff --git a/python/rmm/tests/test_rmm.py b/python/rmm/tests/test_rmm.py index 931ff5336..89d67a9c4 100644 --- a/python/rmm/tests/test_rmm.py +++ b/python/rmm/tests/test_rmm.py @@ -742,6 +742,13 @@ def callback(nbytes: int) -> bool: def test_dev_buf_circle_ref_dealloc(): + # This test creates a reference cycle containing a `DeviceBuffer` + # and ensures that the garbage collector does not clear it, i.e., + # that the GC does not remove all references to other Python + # objects from it. The `DeviceBuffer` needs to keep its reference + # to the `DeviceMemoryResource` that was used to create it in + # order to be cleaned up properly. See GH #931. + rmm.mr.set_current_device_resource(rmm.mr.CudaMemoryResource()) dbuf1 = rmm.DeviceBuffer(size=1_000_000) @@ -751,17 +758,27 @@ def test_dev_buf_circle_ref_dealloc(): l1.append(l1) # due to the reference cycle, the device buffer doesn't actually get - # cleaned up until later, when we invoke `gc.collect()`: + # cleaned up until after `gc.collect()` is called. del dbuf1, l1 rmm.mr.set_current_device_resource(rmm.mr.CudaMemoryResource()) - # by now, the only remaining reference to the *original* memory - # resource should be in `dbuf1`. However, the cyclic garbage collector - # will eliminate that reference when it clears the object via its - # `tp_clear` method. Later, when `tp_dealloc` attemps to actually - # deallocate `dbuf1` (which needs the MR alive), a segfault occurs. + # test that after the call to `gc.collect()`, the `DeviceBuffer` + # is deallocated successfully (i.e., without a segfault). + gc.collect() + + +def test_upstream_mr_circle_ref_dealloc(): + # This test is just like the one above, except it tests that + # instances of `UpstreamResourceAdaptor` (such as + # `PoolMemoryResource`) are not cleared by the GC. + rmm.mr.set_current_device_resource(rmm.mr.CudaMemoryResource()) + mr = rmm.mr.PoolMemoryResource(rmm.mr.get_current_device_resource()) + l1 = [mr] + l1.append(l1) + del mr, l1 + rmm.mr.set_current_device_resource(rmm.mr.CudaMemoryResource()) gc.collect()