
[BUG] Unexpected memory usage on GPU0 #1405

Closed
li-yi-dong opened this issue Dec 7, 2023 · 17 comments · Fixed by #1407
Labels: bug (Something isn't working), doc (Documentation), question (Further information is requested)

@li-yi-dong

Describe the bug
I tried to use RMM with PyTorch. I launched my task with torchrun and set the RMM memory resource for each device at the very beginning:

torch.cuda.change_current_allocator(rmm_torch_allocator)

pool = rmm.mr.PoolMemoryResource(
    rmm.mr.CudaMemoryResource(),
    initial_pool_size=2**30,
)
device = (int(os.environ['LOCAL_RANK']))

rmm.mr.set_per_device_resource(device, pool)

But each process also occupies a chunk of memory on GPU 0:
[screenshot omitted]

Steps/Code to reproduce bug

Expected behavior
I expected each process launched by torchrun to use memory only on the GPU assigned by LOCAL_RANK.

Environment details:
I'm using RMM v23.10.00. Here is the output of print_env.sh:

***GPU Information***
Thu Dec  7 17:18:59 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H800                    On  | 00000000:08:00.0 Off |                    0 |
| N/A   31C    P0              70W / 700W |      7MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H800                    On  | 00000000:7E:00.0 Off |                    0 |
| N/A   26C    P0              69W / 700W |      7MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H800                    On  | 00000000:A2:00.0 Off |                    0 |
| N/A   33C    P0              74W / 700W |      7MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H800                    On  | 00000000:C6:00.0 Off |                    0 |
| N/A   29C    P0              69W / 700W |      7MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H800                    On  | 00000001:09:00.0 Off |                    0 |
| N/A   26C    P0              71W / 700W |      7MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H800                    On  | 00000001:7F:00.0 Off |                    0 |
| N/A   30C    P0              71W / 700W |      7MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H800                    On  | 00000001:A3:00.0 Off |                    0 |
| N/A   30C    P0              71W / 700W |      7MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H800                    On  | 00000001:C7:00.0 Off |                    0 |
| N/A   35C    P0              73W / 700W |      7MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

***CPU***
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                192
On-line CPU(s) list:   0-183
Off-line CPU(s) list:  184-191
Thread(s) per core:    1
Core(s) per socket:    48
Socket(s):             2
NUMA node(s):          2
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 143
Model name:            Intel(R) Xeon(R) Platinum 8469C
Stepping:              8
CPU MHz:               3100.000
CPU max MHz:           3800.0000
CPU min MHz:           800.0000
BogoMIPS:              5200.00
Virtualization:        VT-x
L1d cache:             48K
L1i cache:             32K
L2 cache:              2048K
L3 cache:              99840K
NUMA node0 CPU(s):     0-47,96-143
NUMA node1 CPU(s):     48-95,144-191
Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm uintr md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities

***CMake***
/usr/local/cmake-3.22.0-linux-x86_64/bin/cmake
cmake version 3.22.0

CMake suite maintained and supported by Kitware (kitware.com/cmake).

***g++***
/usr/bin/g++
g++ (GCC) 8.3.1 20190311 (Red Hat 8.3.1-3)
Copyright (C) 2018 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.


***nvcc***
/usr/local/cuda/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Fri_Sep__8_19:17:24_PDT_2023
Cuda compilation tools, release 12.3, V12.3.52
Build cuda_12.3.r12.3/compiler.33281558_0

***Python***
/opt/conda/envs/python3.8/bin/python
Python 3.8.13

***Environment Variables***
PATH                            : /opt/conda/envs/python3.8/bin:/opt/conda/condabin:PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/conda/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/cmake-3.22.0-linux-x86_64/bin:/usr/local/ninja:/home/admin/.local/bin:PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/conda/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/cmake-3.22.0-linux-x86_64/bin:/usr/local/ninja:/home/admin/.local/bin:/usr/local/sbin:PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/conda/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/cmake-3.22.0-linux-x86_64/bin:/usr/local/ninja:/home/admin/.local/bin:/usr/X11R6/bin:/opt/satools
LD_LIBRARY_PATH                 : /usr/local/lib64:/lib64:/usr/local/gcc75/lib:/usr/local/gcc75/lib64::/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/cuda/lib64:/usr/local/cuda/targets/x86_64-linux/lib:/usr/lib64:/pu:/opt/taobao/java/jre/lib/amd64/server:/apsara/alicpp/built/gcc-4.9.2/glog-0.3.4/lib:/apsara/alicpp/built/gcc-4.9.2/gflags-2.1.2/lib:/apsara/alicpp/built/gcc-4.9.2/protobuf-2.4.1.ali/lib:/apsara/alicpp/built/gcc-4.9.2/odps-cryptography-1.0.0/lib:/apsara/alicpp/built/gcc-4.9.2/boost-1.58.0.fix.thread/lib:/apsara/alicpp/built/gcc-4.9.2/openssl-1.0.2a/lib:/apsara/alicpp/built/gcc-4.9.2/mysql-connector-c-6.1.6/lib:/apsara/alicpp/built/gcc-4.9.2/arrow-0.16.0/lib64:/apsara/alicpp/built/gcc-4.9.2/bzip2-1.0.6/lib64:/apsara/alicpp/built/gcc-4.9.2/zstd-1.4.4/lib:/apsara/alicpp/built/gcc-4.9.2/libevent-2.0.22.stable/lib64:/worker:/worker/lib:/opt/conda/envs/python3.8.13/lib:/usr/local/hadoop/hadoop/lib/native
NUMBAPRO_NVVM                   :
NUMBAPRO_LIBDEVICE              :
CONDA_PREFIX                    : /opt/conda/envs/python3.8
PYTHON_PATH                     :

***conda packages***
/opt/conda/condabin/conda
# packages in environment at /opt/conda/envs/python3.8:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main
absl-py                   2.0.0                    pypi_0    pypi
acm-sdk-python            0.4.11                   pypi_0    pypi
ai-scheduler              0.2-89a0e10133697087a464f6bc434bfb9f8b639eb5          pypi_0    pypi
aliyun-python-sdk-core    2.14.0                   pypi_0    pypi
aliyun-python-sdk-kms     2.16.2                   pypi_0    pypi
annotated-types           0.6.0                    pypi_0    pypi
apex                      0.1                      pypi_0    pypi
astunparse                1.6.3                    pypi_0    pypi
async-timeout             4.0.3                    pypi_0    pypi
av                        8.1.0                    pypi_0    pypi
bitsandbytes              0.41.0                   pypi_0    pypi
ca-certificates           2023.08.22           h06a4308_0
cachetools                5.3.2                    pypi_0    pypi
certifi                   2023.11.17               pypi_0    pypi
cffi                      1.16.0                   pypi_0    pypi
chardet                   3.0.4                    pypi_0    pypi
charset-normalizer        3.3.2                    pypi_0    pypi
common-io                 0.8.3                    pypi_0    pypi
crcmod                    1.7                      pypi_0    pypi
cryptography              41.0.7                   pypi_0    pypi
cuda-python               11.8.2                   pypi_0    pypi
cycler                    0.12.1                   pypi_0    pypi
cython                    3.0.6                    pypi_0    pypi
deepspeed                 0.8.2                    pypi_0    pypi
docopt                    0.6.2                    pypi_0    pypi
easydict                  1.11                     pypi_0    pypi
einops                    0.7.0                    pypi_0    pypi
filelock                  3.13.1                   pypi_0    pypi
flash-attn                2.2.1                    pypi_0    pypi
fsspec                    2023.12.0                pypi_0    pypi
future                    0.18.2                   pypi_0    pypi
google-auth               2.24.0                   pypi_0    pypi
google-auth-oauthlib      1.0.0                    pypi_0    pypi
gputil                    1.4.0                    pypi_0    pypi
grpcio                    1.59.3                   pypi_0    pypi
hdfs                      2.7.3                    pypi_0    pypi
hjson                     3.1.0                    pypi_0    pypi
idna                      2.8                      pypi_0    pypi
importlib-metadata        6.8.0                    pypi_0    pypi
intel-openmp              2024.0.0                 pypi_0    pypi
jieba                     0.42.1                   pypi_0    pypi
jinja2                    3.1.2                    pypi_0    pypi
jmespath                  0.10.0                   pypi_0    pypi
joblib                    1.3.2                    pypi_0    pypi
kazoo                     2.9.0                    pypi_0    pypi
kiwisolver                1.4.5                    pypi_0    pypi
kmontitor-client          0.0.0                    pypi_0    pypi
lake-py-lib               0.1.7.ziying             pypi_0    pypi
ld_impl_linux-64          2.38                 h1181459_1
libffi                    3.3                  he6710b0_2
libgcc-ng                 9.1.0                hdf63c60_0
libstdcxx-ng              9.1.0                hdf63c60_0
lightning-utilities       0.10.0                   pypi_0    pypi
llvmlite                  0.41.1                   pypi_0    pypi
lmdb                      0.94                     pypi_0    pypi
lru-dict                  1.3.0                    pypi_0    pypi
magma-cuda121             2.6.1                         1    pytorch
markdown                  3.5.1                    pypi_0    pypi
markupsafe                2.1.3                    pypi_0    pypi
matplotlib                3.3.4                    pypi_0    pypi
mdl                       0.2                      pypi_0    pypi
mkl                       2024.0.0                 pypi_0    pypi
mkl-include               2024.0.0                 pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
ncurses                   6.3                  h7f8727e_2
nebula-mos-python-sdk     0.3.16                   pypi_0    pypi
nebula-py-pangu-early-test 0.0.41                   pypi_0    pypi
networkx                  3.1                      pypi_0    pypi
ninja                     1.11.1.1                 pypi_0    pypi
numba                     0.58.1                   pypi_0    pypi
numpy                     1.24.4                   pypi_0    pypi
nvidia-ml-py3             7.352.0                  pypi_0    pypi
oauthlib                  3.2.2                    pypi_0    pypi
opencv-python             4.5.4.60                 pypi_0    pypi
openssl                   1.1.1w               h7f8727e_0
oss2                      2.18.3                   pypi_0    pypi
packaging                 23.2                     pypi_0    pypi
pandas                    1.1.5                    pypi_0    pypi
pillow                    8.4.0                    pypi_0    pypi
pip                       23.3.1           py38h06a4308_0
protobuf                  3.20.1                   pypi_0    pypi
psutil                    5.9.5                    pypi_0    pypi
py-cpuinfo                9.0.0                    pypi_0    pypi
py-spy                    0.3.14                   pypi_0    pypi
pyarrow                   14.0.1                   pypi_0    pypi
pyasn1                    0.5.1                    pypi_0    pypi
pyasn1-modules            0.3.0                    pypi_0    pypi
pybind11                  2.11.1                   pypi_0    pypi
pycparser                 2.21                     pypi_0    pypi
pycryptodome              3.19.0                   pypi_0    pypi
pydantic                  1.10.9                   pypi_0    pypi
pydantic-core             2.14.5                   pypi_0    pypi
pydicom                   1.2.2                    pypi_0    pypi
pykmonitor                1.0                      pypi_0    pypi
pynvml                    11.5.0                   pypi_0    pypi
pyodps-int                0.11.5                   pypi_0    pypi
pyparsing                 3.1.1                    pypi_0    pypi
pytest-runner             6.0.0                    pypi_0    pypi
python                    3.8.13               h12debd9_0
python-dateutil           2.8.2                    pypi_0    pypi
pytz                      2023.3.post1             pypi_0    pypi
pyyaml                    6.0.1                    pypi_0    pypi
readline                  8.1.2                h7f8727e_1
redis                     5.0.1                    pypi_0    pypi
regex                     2023.10.3                pypi_0    pypi
requests                  2.22.0                   pypi_0    pypi
requests-oauthlib         1.3.1                    pypi_0    pypi
retrying                  1.3.4                    pypi_0    pypi
rmm                       23.10.0                  pypi_0    pypi
rsa                       4.9                      pypi_0    pypi
scikit-learn              0.24.2                   pypi_0    pypi
scipy                     1.7.3                    pypi_0    pypi
sentencepiece             0.1.96                   pypi_0    pypi
setuptools                68.0.0           py38h06a4308_0
simplejson                3.17.6                   pypi_0    pypi
six                       1.16.0                   pypi_0    pypi
sklearn                   0.0.post12               pypi_0    pypi
sqlite                    3.38.5               hc218d9a_0
sympy                     1.12                     pypi_0    pypi
tbb                       2021.11.0                pypi_0    pypi
tensorboard               2.14.0                   pypi_0    pypi
tensorboard-data-server   0.7.2                    pypi_0    pypi
thop-statistics           0.1.1-2303141613          pypi_0    pypi
threadpoolctl             3.2.0                    pypi_0    pypi
thrift                    0.16.0                   pypi_0    pypi
tk                        8.6.12               h1ccaba5_0
torch                     2.1.0+cu121              pypi_0    pypi
torchaudio                2.1.0+cu121              pypi_0    pypi
torchmetrics              1.2.1                    pypi_0    pypi
torchvision               0.16.0+cu121             pypi_0    pypi
tornado                   6.1                      pypi_0    pypi
tqdm                      4.66.1                   pypi_0    pypi
transformer-engine        1.0.0+66d91d5            pypi_0    pypi
transitions               0.9.0                    pypi_0    pypi
triton                    2.1.0                    pypi_0    pypi
typing-extensions         4.8.0                    pypi_0    pypi
urllib3                   1.25.11                  pypi_0    pypi
werkzeug                  3.0.1                    pypi_0    pypi
wheel                     0.41.2           py38h06a4308_0
xformers                  0.0.22.post4             pypi_0    pypi
xz                        5.2.5                h7f8727e_1
zipp                      3.17.0                   pypi_0    pypi
zlib                      1.2.12               h7f8727e_2


li-yi-dong added the '? - Needs Triage' and 'bug' labels on Dec 7, 2023
@wence- (Contributor) commented Dec 7, 2023

torch.cuda.change_current_allocator(rmm_torch_allocator)

pool = rmm.mr.PoolMemoryResource(
    rmm.mr.CudaMemoryResource(),
    initial_pool_size=2**30,
)
device = (int(os.environ['LOCAL_RANK']))

rmm.mr.set_per_device_resource(device, pool)

This is almost correct. However, when creating a memory resource, you should do so with the target device active.

That is, you must call cudaSetDevice(device) before creating the pool and then calling set_per_device_resource.

This can be done like so:

import rmm

device = (int(os.environ['LOCAL_RANK']))
rmm._cuda.gpu.setDevice(device)

pool = rmm.mr.PoolMemoryResource(...)

rmm.mr.set_per_device_resource(device, pool)

Since this is such a common pattern, the top-level rmm.reinitialize has some logic to handle this for you:

import rmm
device = int(os.environ["LOCAL_RANK"])

rmm.reinitialize(devices=device, pool_allocator=True, initial_pool_size=...)

This doesn't have quite as much flexibility in the setup of the allocator, but if you just need a pool on top of a CUDA memory resource then it works fine.
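
For example, here is a sketch of a per-rank setup that combines rmm.reinitialize with the torch allocator hook, using the same 1 GiB pool size as in the report (assembled from the snippets elsewhere in this thread, not tested here):

import os

import rmm
from rmm.allocators.torch import rmm_torch_allocator
import torch

device = int(os.environ["LOCAL_RANK"])

# Install the RMM-backed allocator before any torch CUDA allocations happen.
torch.cuda.change_current_allocator(rmm_torch_allocator)

# Create a pool on top of a CUDA memory resource for this rank's device only.
rmm.reinitialize(devices=device, pool_allocator=True, initial_pool_size=2**30)

# Allocations on this device now come from the RMM pool.
x = torch.zeros(2, 3, device=f"cuda:{device}")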

We could add an interface whereby you provide a zero-argument callback to construct the pool (and rmm.reinitialize would arrange to call it with the correct device active), but we haven't needed it so far.
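
In the meantime, a user-side sketch of that callback pattern (the helper below is hypothetical, not an existing RMM API; it only uses the getDevice/setDevice and set_per_device_resource calls shown above):

import os

import rmm

def make_pool():
    # Runs while the target device is active, so the pool binds to that device.
    return rmm.mr.PoolMemoryResource(
        rmm.mr.CudaMemoryResource(),
        initial_pool_size=2**30,
    )

def set_resource_with_device_active(device, make_mr):
    # Hypothetical helper: activate the device, build the resource, register it, restore.
    previous = rmm._cuda.gpu.getDevice()
    rmm._cuda.gpu.setDevice(device)
    try:
        rmm.mr.set_per_device_resource(device, make_mr())
    finally:
        rmm._cuda.gpu.setDevice(previous)

set_resource_with_device_active(int(os.environ["LOCAL_RANK"]), make_pool)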

wence- added the 'question' and 'doc' labels and removed the 'bug' and '? - Needs Triage' labels on Dec 7, 2023
@li-yi-dong (Author) commented:

@wence- Thanks for your reply!
I tried with

torch.cuda.change_current_allocator(rmm_torch_allocator)
device = (int(os.environ['LOCAL_RANK']))
rmm._cuda.gpu.setDevice(device)

pool = rmm.mr.PoolMemoryResource(
    rmm.mr.CudaMemoryResource(),
    initial_pool_size=2**30,
)

rmm.mr.set_per_device_resource(device, pool)

Unfortunately, it crashed with a core dump:
[screenshot of the crash omitted]

@wence- (Contributor) commented Dec 8, 2023

Hmm, we fixed some bugs around stream-ordered memory resources that will be in 23.12 but are not in 23.10. It's possible that using 23.12 will fix things.

Can you provide a complete example script to run? I will try to reproduce locally.

@li-yi-dong (Author) commented:

Hmm, we fixed some bugs around stream-ordered memory resources that will be in 23.12 but are not in 23.10. It's possible that using 23.12 will fix things.

Can you provide a complete example script to run? I will try to reproduce locally.

Let me try the 23.12

@wence- (Contributor) commented Dec 8, 2023

I tried this trivial code:

import os

import rmm
from rmm.allocators.torch import rmm_torch_allocator
import torch

device = (int(os.environ['LOCAL_RANK']))
torch.cuda.change_current_allocator(rmm_torch_allocator)
rmm._cuda.gpu.setDevice(device)

pool = rmm.mr.PoolMemoryResource(
    rmm.mr.CudaMemoryResource(),
    initial_pool_size=2**30,
)

rmm.mr.set_per_device_resource(device, pool)

print(torch.zeros(2, 3, device=f"cuda:{device}"))

When I run with torchrun --nnodes 1 --nproc-per-node gpu test.py I don't get any errors (both with 23.10 and 23.12)

@li-yi-dong (Author) commented:

(Quoting wence-'s test script and torchrun command from the previous comment.)

Emmm, I tried this code and got some interesting results.

I ran the sample code with torchrun --nnodes 1 --nproc-per-node 8 rmm_test.py and it crashed with a core dump:

[2023-12-11 16:55:07,277] torch.distributed.run: [WARNING]
[2023-12-11 16:55:07,277] torch.distributed.run: [WARNING] *****************************************
[2023-12-11 16:55:07,277] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
[2023-12-11 16:55:07,277] torch.distributed.run: [WARNING] *****************************************
tensor([[0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')
tensor([[0., 0., 0.],
        [0., 0., 0.]], device='cuda:1')
tensor([[0., 0., 0.],
        [0., 0., 0.]], device='cuda:2')
[2023-12-11 16:55:17,290] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: -11) local_rank: 3 (pid: 216181) of binary: /opt/conda/envs/rmm_dev2/bin/python3.10
Traceback (most recent call last):
  File "/opt/conda/envs/rmm_dev2/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/distributed/run.py", line 806, in main
    run(args)
  File "/opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

The backtrace from the core file:

#0  0x00007fd2b6b2e0b8 in ?? () from /lib64/libcuda.so.1
#1  0x00007fd2b6973fd7 in ?? () from /lib64/libcuda.so.1
#2  0x00007fd2f7621a58 in ?? () from /opt/conda/envs/rmm_dev2/lib/libcudart.so.12
#3  0x00007fd2f767901b in cudaEventRecord () from /opt/conda/envs/rmm_dev2/lib/libcudart.so.12
#4  0x00007fd2f75bcb56 in rmm::mr::detail::stream_ordered_memory_resource<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>, rmm::mr::detail::coalescing_free_list>::do_deallocate(void*, unsigned long, rmm::cuda_stream_view) () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/rmm/_lib/memory_resource.cpython-310-x86_64-linux-gnu.so
#5  0x00007fd247d4e3ce in deallocate () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/rmm/_lib/torch_allocator.cpython-310-x86_64-linux-gnu.so
#6  0x00007fd2b60460e3 in torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator::raw_delete(void*) () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libtorch_python.so
#7  0x00007fd2b5979c06 in c10::StorageImpl::~StorageImpl() () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libtorch_python.so
#8  0x00007fd272b58ca7 in c10::intrusive_ptr<c10::StorageImpl, c10::detail::intrusive_target_default_null_type<c10::StorageImpl> >::reset_() ()
   from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libc10.so
#9  0x00007fd272b50cb3 in c10::TensorImpl::~TensorImpl() () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libc10.so
#10 0x00007fd272b50e49 in c10::TensorImpl::~TensorImpl() () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libc10.so
#11 0x00007fd2a02a1a64 in at::native::isfinite(at::Tensor const&)::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const [clone .isra.0] ()
   from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#12 0x00007fd2a02a2777 in at::native::isfinite(at::Tensor const&) () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#13 0x00007fd2a1228ddd in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__isfinite>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&> >, at::Tensor (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) ()
   from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#14 0x00007fd2a0c982eb in at::_ops::isfinite::call(at::Tensor const&) () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libtorch_cpu.so
#15 0x00007fd2b5a79453 in torch::autograd::THPVariable_isfinite(_object*, _object*, _object*) () from /opt/conda/envs/rmm_dev2/lib/python3.10/site-packages/torch/lib/libtorch_python.so
#16 0x00005594adc107e6 in cfunction_call (func=0x7fd2f2f00e00, args=<optimized out>, kwargs=<optimized out>) at /usr/local/src/conda/python-3.10.13/Objects/methodobject.c:543

I tried multiple times and noticed that it only ever printed the tensors on devices 0, 1, and 2. So I tried torchrun --nnodes 1 --nproc-per-node 3 test.py, and it worked just fine. It works fine when the number of GPUs is under 4; when the number of GPUs is 4 or more, it core dumps.

I tried with RMM v23.12.00, Python 3.10 and PyTorch 2.1.1

@wence- (Contributor) commented Dec 11, 2023

Thanks, I'll try and reproduce on a system with more than two GPUs.

wence- self-assigned this on Dec 11, 2023
wence- added the 'bug' label on Dec 11, 2023
@harrism (Member) commented Dec 12, 2023

Is it possible that the active device could be changing before the deallocate is (implicitly) called? The error in cudaEventRecord makes me think that it may be trying to record an event on the wrong device. This MR expects the device that was active when the pool was created to be active when any call to allocate() or deallocate() is made.
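
To illustrate the pattern being asked about, here is a hypothetical sketch (not a verified reproducer; it assumes at least two GPUs and uses only RMM calls that already appear in this thread):

import rmm

rmm._cuda.gpu.setDevice(1)
pool = rmm.mr.PoolMemoryResource(rmm.mr.CudaMemoryResource(), initial_pool_size=2**30)
rmm.mr.set_per_device_resource(1, pool)

buf = rmm.DeviceBuffer(size=1024)  # allocated from device 1's pool while device 1 is active

rmm._cuda.gpu.setDevice(0)         # something else switches the active device
del buf                            # do_deallocate (and its cudaEventRecord) now runs with device 0 active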

@li-yi-dong (Author) commented:

Is it possible that the active device could be changing before the deallocate is (implicitly) called?

I don't think either the trivial code or PyTorch would do so. And that wouldn't explain why it works with fewer than 4 GPUs.

I modified the code into

import os
import time
import rmm
from rmm.allocators.torch import rmm_torch_allocator
import torch

device = (int(os.environ['LOCAL_RANK']))
torch.cuda.change_current_allocator(rmm_torch_allocator)
rmm._cuda.gpu.setDevice(device)

print(rmm._cuda.gpu.getDevice())
pool = rmm.mr.PoolMemoryResource(
    rmm.mr.CudaMemoryResource(),
    initial_pool_size=2**30,
)

rmm.mr.set_per_device_resource(device, pool)
a = torch.zeros(2, 3, device=f"cuda:{device}")
print(a)
print(rmm._cuda.gpu.getDevice())
del a
time.sleep(5)
print(rmm._cuda.gpu.getDevice())

The output:

(rmm_dev2) sh rmm.sh
[2023-12-12 14:33:28,740] torch.distributed.run: [WARNING]
[2023-12-12 14:33:28,740] torch.distributed.run: [WARNING] *****************************************
[2023-12-12 14:33:28,740] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
[2023-12-12 14:33:28,740] torch.distributed.run: [WARNING] *****************************************
6
3
7
2
1
0
4
5
tensor([[0., 0., 0.],
        [0., 0., 0.]], device='cuda:2')
2
tensor([[0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')
0
tensor([[0., 0., 0.],
        [0., 0., 0.]], device='cuda:1')
1
[2023-12-12 14:33:38,751] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 236161 closing signal SIGTERM
[2023-12-12 14:33:38,751] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 236162 closing signal SIGTERM
[2023-12-12 14:33:38,751] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 236163 closing signal SIGTERM
[2023-12-12 14:33:39,029] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: -11) local_rank: 3 (pid: 236164) of binary: /opt/conda/envs/rmm_dev2/bin/python3.10

It seems that the tensors on devices 3, 4, 5, 6, and 7 were deallocated before being printed. (Printing a GPU tensor synchronizes the CPU and GPU in PyTorch.)

@wence- (Contributor) commented Dec 12, 2023

I was able to reproduce this running with four GPUs, but I have yet to figure out what is going on. Debugging under gdb is difficult here because torchrun runs things in subprocesses, but if we run gdb with set detach-on-fork off and set follow-fork-mode child, we can eventually get to the relevant process and I can get a backtrace.

Next step is to build RMM in debug mode so I have some symbols to inspect.

This is what I have right now to debug; note that I only need to allocate things on a single device:

import os
import rmm
from rmm.allocators.torch import rmm_torch_allocator
import torch

device = (int(os.environ['LOCAL_RANK']))
if device == 3:
    torch.cuda.change_current_allocator(rmm_torch_allocator)
    rmm._cuda.gpu.setDevice(device)
    pool = rmm.mr.PoolMemoryResource(
        rmm.mr.CudaMemoryResource(),
        initial_pool_size=2**30,
    )
    rmm.mr.set_per_device_resource(device, pool)
    tensor = torch.zeros(2, 3, device=f"cuda:{device}")
    print(torch.cuda.current_device(), device, os.getpid(), flush=True)
    print(tensor, flush=True)
    del tensor

So my suspicion is that torch is shuffling CUDA devices out from under us in a bad way.

@harrism (Member) commented Dec 12, 2023

Thanks so much for debugging, @wence- .

@wence- (Contributor) commented Dec 12, 2023

OK, I have the culprit.

The signatures we offer for the plug-in allocation functions are:

void *allocate(size_t size, int device, cudaStream_t stream);
void deallocate(void *ptr, size_t size, cudaStream_t stream);

This was the original signature for the pluggable allocators when we introduced this support in #1168; the pluggable-allocator interface itself was introduced in pytorch in pytorch/pytorch#86786.

But soon after, in pytorch/pytorch#91398 the signatures were changed to:

void *allocate(size_t size, int device, cudaStream_t stream);
void deallocate(void *ptr, size_t size, int device, cudaStream_t stream);

Note the change to also accept the device in the deallocate function.

So we're getting 3 (the device id), interpreting it as a stream, and trying to use that in the RMM deallocation function. But of course that stream is nonsense; everyone is actually just using the default 0 stream.

The fix is to fix the signature in RMM (I will prepare a patch).
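
To make the mismatch concrete, here is a rough, self-contained ctypes illustration (hypothetical, unrelated to the real RMM/PyTorch binaries; it relies on the usual C calling convention, where the extra argument simply shifts into the next parameter slot):

import ctypes

# Old-style callee, as currently exported: deallocate(ptr, size, stream).
OldDealloc = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_ssize_t, ctypes.c_void_p)

def old_deallocate(ptr, size, stream):
    print(f"callee sees: ptr={ptr}, size={size}, stream={stream}")

callee = OldDealloc(old_deallocate)

# New-style caller, as pytorch now calls it: deallocate(ptr, size, device, stream).
NewDealloc = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_ssize_t, ctypes.c_int, ctypes.c_void_p)
caller_view = NewDealloc(ctypes.cast(callee, ctypes.c_void_p).value)

# The caller passes device=3 and the default (NULL) stream; the three-argument callee
# picks up the device id in the slot it believes holds the stream.
caller_view(None, 64, 3, None)  # prints: callee sees: ptr=None, size=64, stream=3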

wence- added a commit to wence-/rmm that referenced this issue Dec 12, 2023
Since pytorch/pytorch#91398, the signature of
the pluggable allocate and deallocate functions must accept the device
id. The current version only accepts a device id for allocate, which
means that when using a stream ordered allocator with devices other
than device zero, we pass an invalid stream into the deallocation
function. To fix this, adapt the signature to match the one pytorch
expects.

Now, since we have the device available during allocation and
deallocation, we would like to use that device to obtain the
appropriate memory resource.

Unfortunately, since RMM's cuda_device_id does not have a nullary
constructor, we can't use it in Cython without some hacky workarounds.

However, since we don't actually need to build a Python module, but
rather just a single shared library that offers two extern "C"
functions, let's just write our allocator hooks directly in C++.

- Closes rapidsai#1405
@wence- (Contributor) commented Dec 12, 2023

Here is a minimal diff that will allow your code to run:

diff --git a/python/rmm/_lib/torch_allocator.pyx b/python/rmm/_lib/torch_allocator.pyx
index 12dc9fe1..2b11028c 100644
--- a/python/rmm/_lib/torch_allocator.pyx
+++ b/python/rmm/_lib/torch_allocator.pyx
@@ -15,7 +15,7 @@ cdef public void* allocate(
     return mr[0].allocate(size, stream_view)
 
 cdef public void deallocate(
-    void* ptr, ssize_t size, void* stream
+    void* ptr, ssize_t size, int device, void* stream
 ) except * with gil:
     cdef device_memory_resource* mr = get_current_device_resource()
     cdef cuda_stream_view stream_view = cuda_stream_view(

However, in #1407 I am trying to do a better thing, which is to use the memory resource associated with the device we are being passed, rather than just assuming that get_current_device_resource will return the correct resource. That needs some help from someone on the build side of things.
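
To make that distinction concrete at the Python level, a rough sketch (it assumes at least two GPUs and that per-device resources have been set up as in the earlier snippets):

import rmm

rmm._cuda.gpu.setDevice(0)

# Depends on whichever device happens to be active when the hook runs;
# this is what the current deallocate hook relies on:
mr_active = rmm.mr.get_current_device_resource()

# Keyed explicitly by the device id pytorch passes in; this is what #1407 switches to:
mr_dev1 = rmm.mr.get_per_device_resource(1)

# With device 0 active the two can differ, and memory allocated on device 1 must be
# returned to device 1's resource rather than to the active device's resource.
assert mr_active is rmm.mr.get_per_device_resource(0)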

wence- added a commit to wence-/rmm that referenced this issue Dec 12, 2023
The deallocation function now also takes the device id.

Since both halves of the pair now receive the device on which to
perform the (de)allocation, we switch from using
get_current_device_resource to using the (more correct)
get_per_device_resource. This necessitates a workaround in Cython:
rmm::cuda_device_id has no nullary constructor, and so cannot be
stack-allocated the way Cython transpiles code. Instead perform a heap
allocation and then delete it.

- Closes rapidsai#1405
wence- added a commit to wence-/rmm referencing this issue on Dec 12, 2023 (duplicate commit message omitted)
@wence- (Contributor) commented Dec 12, 2023

Can you try if the code in #1408 works for you @li-yi-dong?

wence- added three more commits to wence-/rmm referencing this issue on Dec 12, 2023 (duplicate commit messages omitted)
@li-yi-dong (Author) commented:

Can you try if the code in #1408 works for you @li-yi-dong?

It works pretty smoothly with my task. And RMM really outperforms the PyTorch caching allocator in terms of fragmentation.

@wence- (Contributor) commented Dec 13, 2023

Great, thanks! In the end we are going with the code in #1407, which I very much hope also works identically; if you could confirm, that would be wonderful.

@li-yi-dong (Author) commented:

Great, thanks! In the end we are going with the code in #1407, which I very much hope also works identically; if you could confirm, that would be wonderful.

It works fine.

rapids-bot bot pushed a commit that referenced this issue Dec 14, 2023
Since pytorch/pytorch#91398, the signature of the pluggable allocate and deallocate functions must accept the device id. The current version only accepts a device id for allocate, which means that when using a stream ordered allocator with devices other than device zero, we pass an invalid stream into the deallocation function. To fix this, adapt the signature to match the one pytorch expects.

Now, since we have the device available during allocation and deallocation, we would like to use that device to obtain the appropriate memory resource.

Unfortunately, since RMM's cuda_device_id does not have a nullary constructor, we can't use it in Cython without some hacky workarounds.

However, since we don't actually need to build a Python module, but rather just a single shared library that offers two extern "C" functions, let's just write our allocator hooks directly in C++.

- Closes #1405

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Mark Harris (https://github.com/harrism)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #1407