[BUG] HuggingFace/PyTorch with dask-cuda-worker does not free memory #383

Closed
VibhuJawa opened this issue Aug 19, 2020 · 12 comments

@VibhuJawa (Member) commented Aug 19, 2020

When I run the exact same example on a dask-cuda worker, the memory is not freed, but it is freed when I run it without dask-cuda.

Below works:

def run_pytorch_func():
    # install transformers by
    # pip install transformers
    from transformers import AutoModelForTokenClassification
    import gc
    import torch

    model_path = 'bert-base-cased'
    model = AutoModelForTokenClassification.from_pretrained(model_path)
    model = model.cuda()
    model = model.eval()
    with torch.no_grad():
        token_tensor = torch.randint(high=1000,size=(200,256)).long().cuda()
        output = model(token_tensor)

    del model
    del token_tensor
    del output
    torch.cuda.empty_cache()
    gc.collect()
    return None

run_pytorch_func()
# !nvidia-smi | head -n 10
# Memory occupied:  1082MiB 
import rmm
rmm.reinitialize(pool_allocator=True,initial_pool_size=30e+9)
# !nvidia-smi | head -n 10
# Memory occupied:  29732MiB

Below fails:

from dask_cuda import LocalCUDACluster
from dask.distributed import Client

cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES=[0])
client = Client(cluster)


def run_pytorch_func():
    # install transformers by
    # pip install transformers
    from transformers import AutoModelForTokenClassification
    import gc
    import torch

    model_path = 'bert-base-cased'
    model = AutoModelForTokenClassification.from_pretrained(model_path)
    model = model.cuda()
    model = model.eval()
    with torch.no_grad():
        token_tensor = torch.randint(high=1000,size=(200,256)).long().cuda()
        output = model(token_tensor)

    del model
    del token_tensor
    del output
    torch.cuda.empty_cache()
    gc.collect()
    return None

import gc
client.run(run_pytorch_func)
client.run(gc.collect)
client.run_on_scheduler(gc.collect)
gc.collect()


# !nvidia-smi | head -n 10
# Memory occupied: 5703MiB

# Below Fails

import rmm
client.run(rmm.reinitialize,pool_allocator=True,initial_pool_size=30e+9)
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-707196a41be2> in <module>
      1 # Below OOMs
      2 import rmm
----> 3 client.run(rmm.reinitialize,pool_allocator=True,initial_pool_size=30e+9)

/raid/vjawa/conda/envs/tpcx-bb-aug-19-torch/lib/python3.7/site-packages/distributed/client.py in run(self, function, *args, **kwargs)
   2490         >>> c.run(print_state, wait=False)  # doctest: +SKIP
   2491         """
-> 2492         return self.sync(self._run, function, *args, **kwargs)
   2493 
   2494     def run_coroutine(self, function, *args, **kwargs):

/raid/vjawa/conda/envs/tpcx-bb-aug-19-torch/lib/python3.7/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    831         else:
    832             return sync(
--> 833                 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    834             )
    835 

/raid/vjawa/conda/envs/tpcx-bb-aug-19-torch/lib/python3.7/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
    337     if error[0]:
    338         typ, exc, tb = error[0]
--> 339         raise exc.with_traceback(tb)
    340     else:
    341         return result[0]

/raid/vjawa/conda/envs/tpcx-bb-aug-19-torch/lib/python3.7/site-packages/distributed/utils.py in f()
    321             if callback_timeout is not None:
    322                 future = asyncio.wait_for(future, callback_timeout)
--> 323             result[0] = yield future
    324         except Exception as exc:
    325             error[0] = sys.exc_info()

/raid/vjawa/conda/envs/tpcx-bb-aug-19-torch/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/raid/vjawa/conda/envs/tpcx-bb-aug-19-torch/lib/python3.7/site-packages/distributed/client.py in _run(self, function, nanny, workers, wait, *args, **kwargs)
   2427             elif resp["status"] == "error":
   2428                 typ, exc, tb = clean_exception(**resp)
-> 2429                 raise exc.with_traceback(tb)
   2430         if wait:
   2431             return results

/raid/vjawa/conda/envs/tpcx-bb-aug-19-torch/lib/python3.7/site-packages/rmm/rmm.py in reinitialize()
     69         devices=devices,
     70         logging=logging,
---> 71         log_file_name=log_file_name,
     72     )
     73 

rmm/_lib/memory_resource.pyx in rmm._lib.memory_resource._initialize()

rmm/_lib/memory_resource.pyx in rmm._lib.memory_resource._initialize()

rmm/_lib/memory_resource.pyx in rmm._lib.memory_resource.PoolMemoryResource.__cinit__()

RuntimeError: RMM failure at: ../include/rmm/mr/device/pool_memory_resource.hpp:100: Initial pool size exceeds the maximum pool size!

Minimal Gists

CC: @jakirkham / @randerzander.

@quasiben (Member)

I am seeing something similar. Oddly, I see two CUDA context creations happening on device 0 (two processes). Note, this only happens with LocalCUDACluster. When starting a worker manually with dask-cuda-worker, I see just one process on device 0.

@quasiben (Member)

I think the second context is the client process initializing as well, so I don't think this is a big concern.

@quasiben (Member)

I suspect torch objects are not being cleaned up nicely with dask-cuda, but I don't know why. As a test I re-ran client.run(torch.cuda.empty_cache) and this cleared out more GPU memory. @VibhuJawa, can you try this as well and see if the additional empty_cache call frees up more memory for you?
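
A minimal sketch of that extra cleanup pass, assuming the client from the reproducer above is still connected (how much memory it frees will vary):

import gc
import torch

# Ask every worker to release PyTorch's cached CUDA blocks, then run an
# extra garbage-collection pass in the worker processes as well.
client.run(torch.cuda.empty_cache)
client.run(gc.collect)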

@jakirkham (Member) commented Aug 19, 2020

Is it possible this is related to the same Numba issue (numba/numba#6147)? I'm thinking of the multiple contexts on the same device. If so, could you please try downgrading to numba=0.50.0?

Also, Peter made a fix yesterday (#379) that we should make sure we are picking up.
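
To double-check which Numba version the workers actually picked up, a small sketch; the numba_version helper is just for illustration and assumes the client from the reproducer above:

def numba_version():
    # Import inside the function so it runs in the worker process.
    import numba
    return numba.__version__

# Returns a dict mapping each worker address to its Numba version.
client.run(numba_version)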

@VibhuJawa (Member, Author)

client.run(torch.cuda.empty_cache)

Will try it, thanks for the tip.

Is it possible this is related to the same Numba issue ( numba/numba#6147 )? Thinking about the multiple contexts on the same device. If so, could you please try downgrading to numba=0.50.0?

I was already on numba=0.50.0, so that should not be the problem.

Also Peter made a fix yesterday ( #379 ) that we should make sure we are getting.

Let me update my env and rerun it.

Thanks a lot for the support, guys.

@jakirkham (Member)

Thanks @VibhuJawa! Please let us know how it goes 🙂

@jakirkham (Member)

Were you able to make any progress here, Vibhu, or are you still stuck?

@VibhuJawa (Member, Author)

Were you able to make any progress here Vibhu or are you still stuck?

Still stuck on it. Sadly, this does not seem to fix it for the workflow at rapidsai/gpu-bdb#84. I will try to take some time to get you guys a better repro.

Sorry for the delay on this; I was pulled into other things.

@jakirkham (Member)

No worries. Thanks for the update 🙂

@VibhuJawa (Member, Author)

As an update on this, the call below cleans up the extra dask-cuda-related memory:

client.run(torch.cuda.empty_cache) 

But there are still some cleanup issues happening (even without dask-cuda). I am unsure whether this is coming from HuggingFace or PyTorch.

Gist: https://gist.github.com/VibhuJawa/bd06afceef8960ce5b99026c14ecac8e

Example:

from transformers import AutoModelForTokenClassification
import gc
import torch

model_path = 'bert-base-cased'
model = AutoModelForTokenClassification.from_pretrained(model_path)
model = model.cuda()
model = model.eval()
with torch.no_grad():
    token_tensor = torch.randint(high=1000,size=(200,256)).long().cuda()
    output = model(token_tensor)

del model
del token_tensor
del output
gc.collect()
torch.cuda.empty_cache()
!nvidia-smi |  head -n 10
Mon Aug 31 18:35:00 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.64.00    Driver Version: 440.64.00    CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla T4            On   | 00000000:3B:00.0 Off |                    0 |
| N/A   55C    P0    29W /  70W |   1066MiB / 15109MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+

Below OOMs

import rmm
rmm.reinitialize(pool_allocator=True,initial_pool_size=15e+9)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-53991fa7fde5> in <module>
      1 import rmm
      2 rmm.reinitialize(pool_allocator=True,
----> 3                  initial_pool_size=15e+9)

/nvme/0/vjawa/conda/envs/tpcxbb-aug-31-pytorch/lib/python3.7/site-packages/rmm/rmm.py in reinitialize(pool_allocator, managed_memory, initial_pool_size, maximum_pool_size, devices, logging, log_file_name)
     75         devices=devices,
     76         logging=logging,
---> 77         log_file_name=log_file_name,
     78     )
     79 

rmm/_lib/memory_resource.pyx in rmm._lib.memory_resource._initialize()

rmm/_lib/memory_resource.pyx in rmm._lib.memory_resource._initialize()

rmm/_lib/memory_resource.pyx in rmm._lib.memory_resource.PoolMemoryResource.__cinit__()

RuntimeError: RMM failure at: ../include/rmm/mr/device/pool_memory_resource.hpp:100: Initial pool size exceeds the maximum pool size!

I am closing this issue here and will raise something on PyTorch or HuggingFace, as the rest is not dask-cuda related.

@quasiben (Member) commented Sep 1, 2020

@VibhuJawa my guess here is that things blow up because PyTorch isn't using RMM. Folks have filed an issue for using external allocators (like RMM) with PyTorch: pytorch/pytorch#43144
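
For completeness, a rough sketch of what that external-allocator integration looks like in later releases; this assumes an RMM version that ships rmm.allocators.torch and PyTorch 2.x with torch.cuda.memory.change_current_allocator, neither of which existed when this issue was filed:

import rmm
import torch
from rmm.allocators.torch import rmm_torch_allocator

# Build an RMM pool and hand it to PyTorch as its CUDA allocator. This must
# happen before PyTorch allocates any GPU memory; the pool size is a placeholder.
rmm.reinitialize(pool_allocator=True, initial_pool_size=2**30)
torch.cuda.memory.change_current_allocator(rmm_torch_allocator)

x = torch.ones(1000, device="cuda")  # now served from the shared RMM pool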

@VibhuJawa (Member, Author) commented Sep 1, 2020

@VibhuJawa my guess here is that things blow up because pytorch isn't using RMM. Folks have filed an issue for using external allocators (like rmm) with PyTorch: pytorch/pytorch#43144

Yup. FWIW, I wouldn't need memory to be freed this aggressively if PyTorch were working with RMM.

The current workflow is as follows:

  • Run a non-PyTorch workflow with the maximum RMM pool.
  • Run a PyTorch workflow with a smaller RMM pool.
  • Run a non-PyTorch workflow with the maximum RMM pool.

And we want to do the above without restarting the workers/client, as these restarts are often finicky on our lab machines, especially at scale, and can take 2+ minutes (when they work correctly).

If we had the RMM pool working as an external allocator with PyTorch, we could have just one pool that gets reused, making workflows like the above much more straightforward.
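
For concreteness, a rough sketch of that alternating pattern using only the calls already shown in this thread; the pool sizes are placeholders, client and run_pytorch_func are the ones from the reproducer above, and each reinitialize only succeeds if the previous stage's memory was actually released:

import gc
import rmm
import torch

# Stage 1: non-PyTorch workflow with a large RMM pool on every worker.
client.run(rmm.reinitialize, pool_allocator=True, initial_pool_size=30e9)
# ... run the non-PyTorch (cuDF/dask) workflow here ...

# Stage 2: shrink the pool so PyTorch's own caching allocator has headroom.
client.run(rmm.reinitialize, pool_allocator=True, initial_pool_size=10e9)
client.run(run_pytorch_func)
client.run(torch.cuda.empty_cache)  # release PyTorch's cached blocks
client.run(gc.collect)

# Stage 3: back to the non-PyTorch workflow with the large pool again.
client.run(rmm.reinitialize, pool_allocator=True, initial_pool_size=30e9)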
