diff --git a/README.md b/README.md
index 91c5c577f..a49430150 100644
--- a/README.md
+++ b/README.md
@@ -732,3 +732,50 @@ This can be done in two ways:
 **Note:** This only configures Numba to use the current RMM resource for allocations. It
 does not initialize nor change the current resource, e.g., enabling a memory pool. See
 [here](#memoryresource-objects) for more information on changing the current memory resource.
+
+### Using RMM with PyTorch
+
+[PyTorch](https://pytorch.org/docs/stable/notes/cuda.html) can use RMM
+for memory allocation. For example, to configure PyTorch to use an
+RMM-managed pool:
+
+```python
+import rmm
+import torch
+
+rmm.reinitialize(pool_allocator=True)
+torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)
+```
+
+PyTorch and RMM will now share the same memory pool.
+
+You can, of course, use a custom memory resource with PyTorch as well:
+
+```python
+import rmm
+import torch
+
+# note that you can configure PyTorch to use RMM either before or
+# after changing RMM's memory resource. PyTorch will use whatever
+# memory resource is configured to be the "current" memory resource at
+# the time of allocation.
+torch.cuda.change_current_allocator(rmm.rmm_torch_allocator)
+
+# configure RMM to use a managed memory resource, wrapped with a
+# statistics resource adaptor that can report information about the
+# amount of memory allocated:
+mr = rmm.mr.StatisticsResourceAdaptor(rmm.mr.ManagedMemoryResource())
+rmm.mr.set_current_device_resource(mr)
+
+x = torch.tensor([1, 2]).cuda()
+
+# the memory resource reports information about PyTorch allocations:
+mr.allocation_counts
+Out[6]:
+{'current_bytes': 16,
+ 'current_count': 1,
+ 'peak_bytes': 16,
+ 'peak_count': 1,
+ 'total_bytes': 16,
+ 'total_count': 1}
+```
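For context, `rmm.reinitialize(pool_allocator=True)` in the first example above is shorthand for installing a pooled memory resource as the current device resource. A rough equivalent written against the existing `rmm.mr` API (an illustrative sketch, not part of this diff) is:

```python
import rmm
import torch

# Build a pool on top of a plain CUDA memory resource and make it the
# current device resource (roughly what reinitialize(pool_allocator=True) does).
pool = rmm.mr.PoolMemoryResource(rmm.mr.CudaMemoryResource())
rmm.mr.set_current_device_resource(pool)

# PyTorch allocations are then served from the RMM pool.
torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)
```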
diff --git a/python/rmm/__init__.py b/python/rmm/__init__.py
index acdeb93a8..9fb13fe73 100644
--- a/python/rmm/__init__.py
+++ b/python/rmm/__init__.py
@@ -25,6 +25,7 @@
     register_reinitialize_hook,
     reinitialize,
     rmm_cupy_allocator,
+    rmm_torch_allocator,
     unregister_reinitialize_hook,
 )
diff --git a/python/rmm/_lib/CMakeLists.txt b/python/rmm/_lib/CMakeLists.txt
index 44f4513b2..9e90d7e99 100644
--- a/python/rmm/_lib/CMakeLists.txt
+++ b/python/rmm/_lib/CMakeLists.txt
@@ -12,7 +12,8 @@
 # the License.
 # =============================================================================
 
-set(cython_sources device_buffer.pyx lib.pyx memory_resource.pyx cuda_stream.pyx)
+set(cython_sources device_buffer.pyx lib.pyx memory_resource.pyx cuda_stream.pyx
+    torch_allocator.pyx)
 set(linked_libraries rmm::rmm)
 
 # Build all of the Cython targets
diff --git a/python/rmm/_lib/memory_resource.pxd b/python/rmm/_lib/memory_resource.pxd
index 387d39866..5bb3746bc 100644
--- a/python/rmm/_lib/memory_resource.pxd
+++ b/python/rmm/_lib/memory_resource.pxd
@@ -17,12 +17,20 @@ from libcpp.memory cimport shared_ptr
 from libcpp.string cimport string
 from libcpp.vector cimport vector
 
+from rmm._lib.cuda_stream_view cimport cuda_stream_view
+
 cdef extern from "rmm/mr/device/device_memory_resource.hpp" \
         namespace "rmm::mr" nogil:
     cdef cppclass device_memory_resource:
         void* allocate(size_t bytes) except +
+        void* allocate(size_t bytes, cuda_stream_view stream) except +
         void deallocate(void* ptr, size_t bytes) except +
+        void deallocate(
+            void* ptr,
+            size_t bytes,
+            cuda_stream_view stream
+        ) except +
 
 
 cdef class DeviceMemoryResource:
     cdef shared_ptr[device_memory_resource] c_obj
diff --git a/python/rmm/_lib/memory_resource.pyx b/python/rmm/_lib/memory_resource.pyx
index 854e14d8c..774db374a 100644
--- a/python/rmm/_lib/memory_resource.pyx
+++ b/python/rmm/_lib/memory_resource.pyx
@@ -29,6 +29,10 @@ from cuda.cudart import cudaError_t
 from rmm._cuda.gpu import CUDARuntimeError, getDevice, setDevice
 
 from rmm._lib.cuda_stream_view cimport cuda_stream_view
+from rmm._lib.per_device_resource cimport (
+    cuda_device_id,
+    set_per_device_resource as cpp_set_per_device_resource,
+)
 
 # Transparent handle of a C++ exception
 ctypedef pair[int, string] CppExcept
@@ -206,29 +210,6 @@ cdef extern from "rmm/mr/device/failure_callback_resource_adaptor.hpp" \
     ) except +
 
 
-cdef extern from "rmm/mr/device/per_device_resource.hpp" namespace "rmm" nogil:
-
-    cdef cppclass cuda_device_id:
-        ctypedef int value_type
-
-        cuda_device_id(value_type id)
-
-        value_type value()
-
-    cdef device_memory_resource* _set_current_device_resource \
-        "rmm::mr::set_current_device_resource" (device_memory_resource* new_mr)
-    cdef device_memory_resource* _get_current_device_resource \
-        "rmm::mr::get_current_device_resource" ()
-
-    cdef device_memory_resource* _set_per_device_resource \
-        "rmm::mr::set_per_device_resource" (
-            cuda_device_id id,
-            device_memory_resource* new_mr
-        )
-    cdef device_memory_resource* _get_per_device_resource \
-        "rmm::mr::get_per_device_resource"(cuda_device_id id)
-
-
 cdef class DeviceMemoryResource:
 
     cdef device_memory_resource* get_mr(self):
@@ -967,7 +948,7 @@ cpdef set_per_device_resource(int device, DeviceMemoryResource mr):
     cdef unique_ptr[cuda_device_id] device_id = \
         make_unique[cuda_device_id](device)
 
-    _set_per_device_resource(deref(device_id), mr.get_mr())
+    cpp_set_per_device_resource(deref(device_id), mr.get_mr())
 
 
 cpdef set_current_device_resource(DeviceMemoryResource mr):
diff --git a/python/rmm/_lib/per_device_resource.pxd b/python/rmm/_lib/per_device_resource.pxd
new file mode 100644
index 000000000..c33217622
--- /dev/null
+++ b/python/rmm/_lib/per_device_resource.pxd
@@ -0,0 +1,23 @@
+from rmm._lib.memory_resource cimport device_memory_resource
+
+
+cdef extern from "rmm/mr/device/per_device_resource.hpp" namespace "rmm" nogil:
+    cdef cppclass cuda_device_id:
+        ctypedef int value_type
+
+        cuda_device_id(value_type id)
+
+        value_type value()
+
+cdef extern from "rmm/mr/device/per_device_resource.hpp" \
+        namespace "rmm::mr" nogil:
"rmm::mr" nogil: + cdef device_memory_resource* set_current_device_resource( + device_memory_resource* new_mr + ) + cdef device_memory_resource* get_current_device_resource() + cdef device_memory_resource* set_per_device_resource( + cuda_device_id id, device_memory_resource* new_mr + ) + cdef device_memory_resource* get_per_device_resource ( + cuda_device_id id + ) diff --git a/python/rmm/_lib/torch_allocator.pyx b/python/rmm/_lib/torch_allocator.pyx new file mode 100644 index 000000000..12dc9fe11 --- /dev/null +++ b/python/rmm/_lib/torch_allocator.pyx @@ -0,0 +1,24 @@ +from cuda.ccudart cimport cudaStream_t + +from rmm._lib.cuda_stream_view cimport cuda_stream_view +from rmm._lib.memory_resource cimport device_memory_resource +from rmm._lib.per_device_resource cimport get_current_device_resource + + +cdef public void* allocate( + ssize_t size, int device, void* stream +) except * with gil: + cdef device_memory_resource* mr = get_current_device_resource() + cdef cuda_stream_view stream_view = cuda_stream_view( + (stream) + ) + return mr[0].allocate(size, stream_view) + +cdef public void deallocate( + void* ptr, ssize_t size, void* stream +) except * with gil: + cdef device_memory_resource* mr = get_current_device_resource() + cdef cuda_stream_view stream_view = cuda_stream_view( + (stream) + ) + mr[0].deallocate(ptr, size, stream_view) diff --git a/python/rmm/rmm.py b/python/rmm/rmm.py index 398d83de3..cae9971dc 100644 --- a/python/rmm/rmm.py +++ b/python/rmm/rmm.py @@ -237,6 +237,21 @@ def rmm_cupy_allocator(nbytes): return ptr +try: + from torch.cuda.memory import CUDAPluggableAllocator +except ImportError: + rmm_torch_allocator = None +else: + import rmm._lib.torch_allocator + + _alloc_free_lib_path = rmm._lib.torch_allocator.__file__ + rmm_torch_allocator = CUDAPluggableAllocator( + _alloc_free_lib_path, + alloc_fn_name="allocate", + free_fn_name="deallocate", + ) + + def register_reinitialize_hook(func, *args, **kwargs): """ Add a function to the list of functions ("hooks") that will be diff --git a/python/rmm/tests/conftest.py b/python/rmm/tests/conftest.py new file mode 100644 index 000000000..5fad81c79 --- /dev/null +++ b/python/rmm/tests/conftest.py @@ -0,0 +1,21 @@ +import pytest + +import rmm + + +@pytest.fixture(scope="function", autouse=True) +def rmm_auto_reinitialize(): + # Run the test + yield + + # Automatically reinitialize the current memory resource after running each + # test + + rmm.reinitialize() + + +@pytest.fixture +def stats_mr(): + mr = rmm.mr.StatisticsResourceAdaptor(rmm.mr.CudaMemoryResource()) + rmm.mr.set_current_device_resource(mr) + return mr diff --git a/python/rmm/tests/test_rmm.py b/python/rmm/tests/test_rmm.py index 89d67a9c4..f79c60b43 100644 --- a/python/rmm/tests/test_rmm.py +++ b/python/rmm/tests/test_rmm.py @@ -42,17 +42,6 @@ ) -@pytest.fixture(scope="function", autouse=True) -def rmm_auto_reinitialize(): - - # Run the test - yield - - # Automatically reinitialize the current memory resource after running each - # test - rmm.reinitialize() - - def array_tester(dtype, nelem, alloc): # data h_in = np.full(nelem, 3.2, dtype) @@ -604,20 +593,14 @@ def test_cuda_async_memory_resource_threshold(nelem, alloc): array_tester("u1", 2 * nelem, alloc) # should trigger release -def test_statistics_resource_adaptor(): - - cuda_mr = rmm.mr.CudaMemoryResource() - - mr = rmm.mr.StatisticsResourceAdaptor(cuda_mr) - - rmm.mr.set_current_device_resource(mr) +def test_statistics_resource_adaptor(stats_mr): buffers = [rmm.DeviceBuffer(size=1000) for _ in range(10)] for 
diff --git a/python/rmm/tests/conftest.py b/python/rmm/tests/conftest.py
new file mode 100644
index 000000000..5fad81c79
--- /dev/null
+++ b/python/rmm/tests/conftest.py
@@ -0,0 +1,21 @@
+import pytest
+
+import rmm
+
+
+@pytest.fixture(scope="function", autouse=True)
+def rmm_auto_reinitialize():
+    # Run the test
+    yield
+
+    # Automatically reinitialize the current memory resource after running each
+    # test
+
+    rmm.reinitialize()
+
+
+@pytest.fixture
+def stats_mr():
+    mr = rmm.mr.StatisticsResourceAdaptor(rmm.mr.CudaMemoryResource())
+    rmm.mr.set_current_device_resource(mr)
+    return mr
diff --git a/python/rmm/tests/test_rmm.py b/python/rmm/tests/test_rmm.py
index 89d67a9c4..f79c60b43 100644
--- a/python/rmm/tests/test_rmm.py
+++ b/python/rmm/tests/test_rmm.py
@@ -42,17 +42,6 @@
 )
 
 
-@pytest.fixture(scope="function", autouse=True)
-def rmm_auto_reinitialize():
-
-    # Run the test
-    yield
-
-    # Automatically reinitialize the current memory resource after running each
-    # test
-    rmm.reinitialize()
-
-
 def array_tester(dtype, nelem, alloc):
     # data
     h_in = np.full(nelem, 3.2, dtype)
@@ -604,20 +593,14 @@ def test_cuda_async_memory_resource_threshold(nelem, alloc):
     array_tester("u1", 2 * nelem, alloc)  # should trigger release
 
 
-def test_statistics_resource_adaptor():
-
-    cuda_mr = rmm.mr.CudaMemoryResource()
-
-    mr = rmm.mr.StatisticsResourceAdaptor(cuda_mr)
-
-    rmm.mr.set_current_device_resource(mr)
+def test_statistics_resource_adaptor(stats_mr):
 
     buffers = [rmm.DeviceBuffer(size=1000) for _ in range(10)]
 
     for i in range(9, 0, -2):
         del buffers[i]
 
-    assert mr.allocation_counts == {
+    assert stats_mr.allocation_counts == {
         "current_bytes": 5000,
         "current_count": 5,
         "peak_bytes": 10000,
@@ -627,7 +610,7 @@ def test_statistics_resource_adaptor():
     }
 
     # Push a new Tracking adaptor
-    mr2 = rmm.mr.StatisticsResourceAdaptor(mr)
+    mr2 = rmm.mr.StatisticsResourceAdaptor(stats_mr)
     rmm.mr.set_current_device_resource(mr2)
 
     for _ in range(2):
@@ -641,7 +624,7 @@ def test_statistics_resource_adaptor():
         "total_bytes": 2000,
         "total_count": 2,
     }
-    assert mr.allocation_counts == {
+    assert stats_mr.allocation_counts == {
         "current_bytes": 7000,
         "current_count": 7,
         "peak_bytes": 10000,
@@ -661,7 +644,7 @@ def test_statistics_resource_adaptor():
         "total_bytes": 2000,
         "total_count": 2,
     }
-    assert mr.allocation_counts == {
+    assert stats_mr.allocation_counts == {
         "current_bytes": 0,
         "current_count": 0,
         "peak_bytes": 10000,
@@ -669,10 +652,10 @@
         "total_bytes": 12000,
         "total_count": 12,
     }
+    gc.collect()
 
 
 def test_tracking_resource_adaptor():
-
     cuda_mr = rmm.mr.CudaMemoryResource()
 
     mr = rmm.mr.TrackingResourceAdaptor(cuda_mr, capture_stacks=True)
diff --git a/python/rmm/tests/test_rmm_pytorch.py b/python/rmm/tests/test_rmm_pytorch.py
new file mode 100644
index 000000000..eaa40c0ed
--- /dev/null
+++ b/python/rmm/tests/test_rmm_pytorch.py
@@ -0,0 +1,37 @@
+import gc
+
+import pytest
+
+import rmm
+
+torch = pytest.importorskip("torch")
+
+
+@pytest.fixture(scope="session")
+def torch_allocator():
+    try:
+        from torch.cuda.memory import change_current_allocator
+    except ImportError:
+        pytest.skip("pytorch pluggable allocator not available")
+    change_current_allocator(rmm.rmm_torch_allocator)
+
+
+def test_rmm_torch_allocator(torch_allocator, stats_mr):
+    assert stats_mr.allocation_counts["current_bytes"] == 0
+    x = torch.tensor([1, 2]).cuda()
+    assert stats_mr.allocation_counts["current_bytes"] > 0
+    del x
+    gc.collect()
+    assert stats_mr.allocation_counts["current_bytes"] == 0
+
+
+def test_rmm_torch_allocator_using_stream(torch_allocator, stats_mr):
+    assert stats_mr.allocation_counts["current_bytes"] == 0
+    s = torch.cuda.Stream()
+    with torch.cuda.stream(s):
+        x = torch.tensor([1, 2]).cuda()
+    torch.cuda.current_stream().wait_stream(s)
+    assert stats_mr.allocation_counts["current_bytes"] > 0
+    del x
+    gc.collect()
+    assert stats_mr.allocation_counts["current_bytes"] == 0
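Taken together, the new tests exercise the whole path: the `stats_mr` fixture from `conftest.py` installs a `StatisticsResourceAdaptor`, the session-scoped `torch_allocator` fixture routes PyTorch through RMM, and the assertions confirm that tensor allocations (including those made on a non-default stream) appear in RMM's counters. The same check can be written outside pytest; a sketch mirroring the tests above (illustrative only):

```python
import gc

import rmm
import torch

# Route PyTorch through RMM, then install a statistics adaptor so that
# PyTorch allocations are visible in RMM's counters.
torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)
mr = rmm.mr.StatisticsResourceAdaptor(rmm.mr.CudaMemoryResource())
rmm.mr.set_current_device_resource(mr)

x = torch.tensor([1, 2]).cuda()
assert mr.allocation_counts["current_bytes"] > 0

del x
gc.collect()  # make sure the tensor is actually released before checking
assert mr.allocation_counts["current_bytes"] == 0
```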