[BUG] Unexpected memory usage on GPU0 #1405
This is almost correct. However, when creating a memory resource, you should do so with the target device active. That is, you must set the target device (for example with rmm._cuda.gpu.setDevice) before constructing the pool. This can be done like so:

import os
import rmm

device = int(os.environ["LOCAL_RANK"])
rmm._cuda.gpu.setDevice(device)
pool = rmm.mr.PoolMemoryResource(...)
rmm.mr.set_per_device_resource(device, pool)

Since this is such a common pattern, the top-level rmm.reinitialize function handles it for you:

import os
import rmm

device = int(os.environ["LOCAL_RANK"])
rmm.reinitialize(devices=device, pool_allocator=True, initial_pool_size=...)

This doesn't have quite as much flexibility in the set-up of the allocator, but if you just need a pool on top of a cuda memory resource then it works fine. We could add an interface whereby you provide a zero-argument callback to construct the pool; a hypothetical sketch of that idea follows.
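Purely as a hypothetical sketch (neither reinitialize_from_callback nor its arguments exist in RMM; the names are invented only to illustrate the idea), such an interface might be used like this:

import os

import rmm

def make_pool():
    # RMM would call this with the target device already active, so the
    # pool is constructed on, and serves allocations for, the right GPU.
    return rmm.mr.PoolMemoryResource(
        rmm.mr.CudaMemoryResource(),
        initial_pool_size=2**30,
    )

device = int(os.environ["LOCAL_RANK"])
# Hypothetical API call, shown only to illustrate the shape of the interface:
# rmm.reinitialize_from_callback(devices=device, make_mr=make_pool)
|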
@wence- Thanks for your reply!
|
Hmm, we fixed some bugs around stream-ordered memory resources that will be in the 23.12 release. Can you provide a complete example script to run, and I will try to reproduce locally. |
Let me try the 23.12 release. |
I tried this trivial code:

import os

import rmm
from rmm.allocators.torch import rmm_torch_allocator
import torch

# Pick the GPU for this process, route torch's CUDA allocations through RMM,
# and register a pool for that device before any allocation happens.
device = int(os.environ["LOCAL_RANK"])
torch.cuda.change_current_allocator(rmm_torch_allocator)
rmm._cuda.gpu.setDevice(device)
pool = rmm.mr.PoolMemoryResource(
    rmm.mr.CudaMemoryResource(),
    initial_pool_size=2**30,
)
rmm.mr.set_per_device_resource(device, pool)
print(torch.zeros(2, 3, device=f"cuda:{device}"))

When I run this with torchrun across my two GPUs it works as expected.
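The script relies on torchrun setting LOCAL_RANK for each spawned process, so it is launched along these lines (the script name is just a placeholder):

torchrun --nproc_per_node=2 test_rmm.py
|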
Emmm, I tried this code and got some interesting results. I ran the sample code with torchrun across all 8 GPUs:
The core file:
I tried multiple times and noticed that every time it could only print the tensors on devices 0, 1 and 2. So, I tried with fewer GPUs; runs using fewer than 4 GPUs worked. I tried with RMM v23.12.00, Python 3.10 and PyTorch 2.1.1. |
Thanks, I'll try and reproduce on a system with more than two GPUs. |
Is it possible that the active device could be changing before the allocations happen? |
I don't think either the trivial code or PyTorch would do so, and that could not explain why fewer than 4 GPUs worked. I modified the code into:
The output:
It seems that the tensors on devices 3, 4, 5, 6 and 7 have been deallocated before they were printed. |
I was able to reproduce running with four GPUs; I have yet to figure out what is going on. Debugging under gdb is difficult here because torchrun is running things in subprocesses. The next step is to build RMM in debug mode so that I have some symbols to inspect. This is what I have right now to debug; note that I only need to allocate things on a single device:
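Roughly, a minimal reproducer of that shape (the specific rank and tensor size here are illustrative) looks like:

import os

import rmm
from rmm.allocators.torch import rmm_torch_allocator
import torch

device = int(os.environ["LOCAL_RANK"])
torch.cuda.change_current_allocator(rmm_torch_allocator)
rmm._cuda.gpu.setDevice(device)
rmm.mr.set_per_device_resource(
    device,
    rmm.mr.PoolMemoryResource(rmm.mr.CudaMemoryResource(), initial_pool_size=2**30),
)

# Every rank sets up its pool, but only one rank actually allocates,
# which is enough to exercise the allocate/deallocate path under test.
if device == 3:
    print(torch.zeros(2, 3, device=f"cuda:{device}"))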
So my suspicion is that torch is shuffling cuda devices out from under us in a bad way. |
Thanks so much for debugging, @wence-. |
OK, I have the culprit. The signature we offer for the pluggable allocation functions is:
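(as declared in python/rmm/_lib/torch_allocator.pyx; reconstructed here from the diff shown further down)

cdef public void* allocate(
    ssize_t size, int device, void* stream
) except * with gil

cdef public void deallocate(
    void* ptr, ssize_t size, void* stream
) except * with gil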
This was the original signature for the pluggable allocators when we introduced support in #1168; the pluggable-allocator interface itself was introduced in pytorch in pytorch/pytorch#86786. But soon after, in pytorch/pytorch#91398, the signatures were changed to:
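(written here in the same Cython form as above; note the extra device argument to deallocate)

cdef public void* allocate(
    ssize_t size, int device, void* stream
) except * with gil

cdef public void deallocate(
    void* ptr, ssize_t size, int device, void* stream
) except * with gil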
Note the change to also accept the device in the deallocate function. So on deallocation we are handed the device id where we expect the stream, and we end up passing an invalid stream to the memory resource. The fix is to fix the signature in RMM (I will prepare a patch). |
Here is a minimal diff that will allow your code to run:

diff --git a/python/rmm/_lib/torch_allocator.pyx b/python/rmm/_lib/torch_allocator.pyx
index 12dc9fe1..2b11028c 100644
--- a/python/rmm/_lib/torch_allocator.pyx
+++ b/python/rmm/_lib/torch_allocator.pyx
@@ -15,7 +15,7 @@ cdef public void* allocate(
     return mr[0].allocate(size, stream_view)
 cdef public void deallocate(
-    void* ptr, ssize_t size, void* stream
+    void* ptr, ssize_t size, int device, void* stream
 ) except * with gil:
     cdef device_memory_resource* mr = get_current_device_resource()
     cdef cuda_stream_view stream_view = cuda_stream_view(

However, in #1407 I am trying to do a better thing, which is to use the memory resource associated with the device we are being passed, rather than just assuming that the current device's resource is the right one to use. |
The deallocation function now also takes the device id. Since both halves of the pair now receive the device on which to perform the (de)allocation, we switch from using get_current_device_resource to using the (more correct) get_per_device_resource.

This necessitates a workaround in Cython: rmm::cuda_device_id has no nullary constructor, and so cannot be stack-allocated the way Cython transpiles code. Instead perform a heap allocation and then delete it.

- Closes rapidsai#1405
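A sketch of that workaround in the Cython deallocate hook (assuming cuda_device_id, get_per_device_resource and the other names are cimported from RMM's Cython declarations):

cdef public void deallocate(
    void* ptr, ssize_t size, int device, void* stream
) except * with gil:
    # cuda_device_id has no nullary constructor, so heap-allocate it instead
    # of letting Cython try to default-construct one on the stack.
    cdef cuda_device_id* device_id = new cuda_device_id(device)
    cdef device_memory_resource* mr = get_per_device_resource(device_id[0])
    cdef cuda_stream_view stream_view = cuda_stream_view(
        <cudaStream_t>(stream)
    )
    try:
        mr[0].deallocate(ptr, size, stream_view)
    finally:
        del device_id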
Can you check whether the code in #1408 works for you, @li-yi-dong? |
It works pretty smoothly with my task. And RMM really outperforms the PyTorch caching allocator in terms of fragmentation. |
Great, thanks! In the end we are going with the code in #1407, which I very much hope works identically; if you could confirm that, it would be wonderful. |
It works fine. |
Since pytorch/pytorch#91398, the signature of the pluggable allocate and deallocate functions must accept the device id. The current version only accepts a device id for allocate, which means that when using a stream ordered allocator with devices other than device zero, we pass an invalid stream into the deallocation function. To fix this, adapt the signature to match the one pytorch expects.

Now, since we have the device available during allocation and deallocation, we would like to use that device to obtain the appropriate memory resource. Unfortunately, since RMM's cuda_device_id does not have a nullary constructor, we can't use it in Cython without some hacky workarounds. However, since we don't actually need to build a Python module, but rather just a single shared library that offers two extern "C" functions, let's just write our allocator hooks directly in C++.

- Closes #1405

Authors:
- Lawrence Mitchell (https://github.com/wence-)

Approvers:
- Mark Harris (https://github.com/harrism)
- Vyas Ramasubramani (https://github.com/vyasr)

URL: #1407
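The shape of those two extern "C" hooks, as a simplified C++ sketch of the approach described above (not the exact code merged in #1407):

#include <rmm/cuda_device.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

#include <cuda_runtime_api.h>

#include <cstddef>

// Look up the memory resource registered for the device pytorch passes in,
// rather than whichever device happens to be current in the calling thread.
extern "C" void* allocate(std::size_t size, int device, void* stream)
{
  rmm::cuda_device_id const device_id{device};
  auto* mr = rmm::mr::get_per_device_resource(device_id);
  return mr->allocate(size, rmm::cuda_stream_view{static_cast<cudaStream_t>(stream)});
}

extern "C" void deallocate(void* ptr, std::size_t size, int device, void* stream)
{
  rmm::cuda_device_id const device_id{device};
  auto* mr = rmm::mr::get_per_device_resource(device_id);
  mr->deallocate(ptr, size, rmm::cuda_stream_view{static_cast<cudaStream_t>(stream)});
}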
Describe the bug
I tried to use RMM with PyTorch. I launch my task with torchrun and set the rmm.mr for each device at the very beginning.
But each process occupies a chunk of memory on GPU0, like this:
Steps/Code to reproduce bug
Expected behavior
I expected each process launched by torchrun to use only the memory on the GPU assigned by LOCAL_RANK.
Environment details (please complete the following information):
I'm using RMM v23.10.00
Here is the output of print_env.sh: