[QST] GPU out-of-memory errors when applying UMAP to extremely large SAE feature matrices #6167

Open
lc82111 opened this issue Dec 8, 2024 · 3 comments
Labels
? - Needs Triage (Need team to review and classify), question (Further information is requested)

Comments


lc82111 commented Dec 8, 2024

Hi beckernick,

I'm encountering GPU out-of-memory errors when applying UMAP to extremely large SAE feature matrices (~400GB in fp32 format). My environment details are as follows:

OS: Ubuntu 20.04
RAPIDS: 24.12
GPU: 3090
CPU Memory: 512 GB
Python: 3.12.8
I suspect that setting nnd_n_clusters=1 might help with the memory usage. Could you provide guidance or share any implementation strategies that could handle such large-scale feature matrices efficiently?

Here is a list of installed packages for reference:


Package                   Version
------------------------- --------------
accelerate                1.1.1
aiohappyeyeballs          2.4.4
aiohttp                   3.11.9
aiosignal                 1.3.1
anyio                     4.6.2.post1
argon2-cffi               23.1.0
argon2-cffi-bindings      21.2.0
arrow                     1.3.0
asttokens                 3.0.0
async-lru                 2.0.4
attrs                     24.2.0
babel                     2.16.0
beautifulsoup4            4.12.3
bleach                    6.2.0
bokeh                     3.6.2
branca                    0.7.2
Brotli                    1.1.0
cached-property           1.5.2
cachetools                5.5.0
certifi                   2024.8.30
cffi                      1.17.1
charset-normalizer        3.4.0
click                     8.1.7
cloudpickle               3.1.0
colorama                  0.4.6
colorcet                  3.1.0
comm                      0.2.2
confluent-kafka           2.5.3
contourpy                 1.3.1
cucim                     24.12.0a12
cuda-python               12.6.0
cudf                      24.12.0a400
cudf_kafka                24.12.0a400
cudf-polars               24.12.0a400
cugraph                   24.12.0a86
cuml                      24.12.0a43
cuproj                    24.12.0a23
cupy                      13.3.0
cuspatial                 24.12.0a23
custreamz                 24.12.0a400
cuvs                      24.12.0a91
cuxfilter                 24.12.0a14
cycler                    0.12.1
cytoolz                   1.0.0
dask                      2024.11.2
dask-cuda                 24.12.0a15
dask-cudf                 24.12.0a400
dask-expr                 1.1.19
datashader                0.16.3
debugpy                   1.8.9
decorator                 5.1.1
defusedxml                0.7.1
distributed               2024.11.2
distributed-ucxx          0.41.0
entrypoints               0.4
exceptiongroup            1.2.2
executing                 2.1.0
fastjsonschema            2.21.1
fastrlock                 0.8.2
filelock                  3.16.1
folium                    0.18.0
fonttools                 4.55.1
fqdn                      1.5.1
frozenlist                1.5.0
fsspec                    2024.10.0
geopandas                 1.0.1
gmpy2                     2.1.5
h11                       0.14.0
h2                        4.1.0
h5py                      3.12.1
holoviews                 1.20.0
hpack                     4.0.0
httpcore                  1.0.7
httpx                     0.28.0
huggingface-hub           0.26.3
hyperframe                6.0.1
idna                      3.10
imagecodecs               2024.9.22
imageio                   2.36.1
importlib_metadata        8.5.0
importlib_resources       6.4.5
ipykernel                 6.29.5
ipython                   8.30.0
isoduration               20.11.0
jedi                      0.19.2
Jinja2                    3.1.4
joblib                    1.4.2
json5                     0.10.0
jsonpointer               3.0.0
jsonschema                4.23.0
jsonschema-specifications 2024.10.1
jupyter_client            8.6.3
jupyter_core              5.7.2
jupyter-events            0.10.0
jupyter-lsp               2.2.5
jupyter_server            2.14.2
jupyter_server_proxy      4.4.0
jupyter_server_terminals  0.5.3
jupyterlab                4.3.2
jupyterlab_pygments       0.3.0
jupyterlab_server         2.27.3
kiwisolver                1.4.7
lazy_loader               0.4
linkify-it-py             2.0.3
llvmlite                  0.43.0
locket                    1.0.0
lz4                       4.3.3
mapclassify               2.8.1
Markdown                  3.6
markdown-it-py            3.0.0
MarkupSafe                3.0.2
matplotlib                3.9.3
matplotlib-inline         0.1.7
mdit-py-plugins           0.4.2
mdurl                     0.1.2
mistune                   3.0.2
mpmath                    1.3.0
msgpack                   1.1.0
multidict                 6.1.0
multipledispatch          0.6.0
munkres                   1.1.4
nbclient                  0.10.1
nbconvert                 7.16.4
nbformat                  5.10.4
nest_asyncio              1.6.0
networkx                  3.4.2
notebook_shim             0.2.4
numba                     0.60.0
numba-cuda                0.0.17
numpy                     2.0.2
nvtx                      0.2.10
nx-cugraph                24.12.0a85
overrides                 7.7.0
packaging                 24.2
pandas                    2.2.2
pandocfilters             1.5.0
panel                     1.5.4
param                     2.1.1
parso                     0.8.4
partd                     1.4.2
pexpect                   4.9.0
pickleshare               0.7.5
pillow                    11.0.0
pip                       24.3.1
pkgutil_resolve_name      1.3.10
platformdirs              4.3.6
polars                    1.14.0
prometheus_client         0.21.1
prompt_toolkit            3.0.48
propcache                 0.2.1
psutil                    6.1.0
ptyprocess                0.7.0
pure_eval                 0.2.3
pyarrow                   17.0.0
pycparser                 2.22
pyct                      0.5.0
Pygments                  2.18.0
pylibcudf                 24.12.0a400
pylibcugraph              24.12.0a86
pylibraft                 24.12.0a46
pynvjitlink-cu12          0.4.0
pynvml                    11.5.3
pyogrio                   0.10.0
pyparsing                 3.2.0
pyproj                    3.7.0
PySocks                   1.7.1
python-dateutil           2.9.0.post0
python-json-logger        2.0.7
pytz                      2024.2
pyviz_comms               3.0.3
PyWavelets                1.8.0
PyYAML                    6.0.2
pyzmq                     26.2.0
raft-dask                 24.12.0a46
rapids-dask-dependency    24.12.0a0
referencing               0.35.1
regex                     2024.11.6
requests                  2.32.3
rfc3339-validator         0.1.4
rfc3986-validator         0.1.1
rich                      13.9.4
rmm                       24.12.0a33
rpds-py                   0.22.3
safetensors               0.4.5
scikit-image              0.24.0
scikit-learn              1.5.2
scipy                     1.14.1
Send2Trash                1.8.3
setuptools                75.6.0
shapely                   2.0.6
simpervisor               1.0.0
six                       1.17.0
sniffio                   1.3.1
sortedcontainers          2.4.0
soupsieve                 2.5
stack-data                0.6.2
streamz                   0.6.4
sympy                     1.13.3
tblib                     3.0.0
terminado                 0.18.1
threadpoolctl             3.5.0
tifffile                  2024.9.20
tinycss2                  1.4.0
tokenizers                0.21.0
tomli                     2.2.1
toolz                     1.0.0
torch                     2.5.1.post303
tornado                   6.4.2
tqdm                      4.67.1
traitlets                 5.14.3
transformers              4.47.0
treelite                  4.3.0
types-python-dateutil     2.9.0.20241003
typing_extensions         4.12.2
typing_utils              0.1.0
tzdata                    2024.2
uc-micro-py               1.0.3
ucx-py                    0.41.0a13
ucxx                      0.41.0
unicodedata2              15.1.0
uri-template              1.3.0
urllib3                   2.2.3
wcwidth                   0.2.13
webcolors                 24.11.1
webencodings              0.5.1
websocket-client          1.8.0
wheel                     0.45.1
xarray                    2024.11.0
xgboost                   2.1.2
xyzservices               2024.9.0
yarl                      1.18.3
zict                      3.0.0
zipp                      3.21.0
zstandard                 0.23.0

Thank you for your assistance!

Best regards,

lc82111 added the "? - Needs Triage" (Need team to review and classify) and "question" (Further information is requested) labels Dec 8, 2024
Member

dantegd commented Dec 12, 2024

Hi @lc82111, that is quite a large dataset for a 3090! I saw the discussion in RE-N-Y/sae#1 ("Apply UMAP to the SAEs features"). I will add the details here, and we will look into it.

Member

cjnolet commented Dec 12, 2024

@lc82111 We recently released a new algorithm for scaling to massive datasets that are larger than the memory available on the GPU. The algorithm works by breaking the dataset into some number of partitions (using k-means as the clustering algorithm) so that each partition CAN fit on the GPU.

The ideal setting for nnd_n_clusters would create enough partitions that the data assigned to each partition can fit on the GPU. Too many clusters would not give the GPU enough work to do, and an excessive amount of time would be spent copying the clusters back and forth between GPU memory and RAM.
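As a rough sketch of how nnd_n_clusters might be passed to cuML's UMAP: the snippet below assumes that UMAP accepts build_algo="nn_descent", a build_kwds dict with an "nnd_n_clusters" key, and data_on_host=True in fit_transform. Please check the documentation of your installed release for the exact keys that build_kwds accepts. The helper names make_build_kwds and fit_umap_batched are hypothetical and exist only for illustration.

```python
def make_build_kwds(n_clusters):
    """Build the NN-descent options dict (hypothetical helper, not part of cuML)."""
    if n_clusters < 1:
        raise ValueError("n_clusters must be at least 1")
    return {"nnd_n_clusters": int(n_clusters)}


def fit_umap_batched(X_host, n_clusters, n_neighbors=15):
    """Fit UMAP on a host (CPU) array using partitioned NN-descent.

    Assumption: the installed cuML exposes build_algo, build_kwds and
    data_on_host as used here. Not verified against every release.
    """
    # Imported lazily so this file can be read and tested without a GPU.
    from cuml.manifold import UMAP

    reducer = UMAP(
        n_neighbors=n_neighbors,
        build_algo="nn_descent",
        build_kwds=make_build_kwds(n_clusters),
    )
    # data_on_host=True keeps the full matrix in RAM instead of copying it
    # to the GPU in one piece.
    return reducer.fit_transform(X_host, data_on_host=True)


print(make_build_kwds(25))
```

Calling fit_umap_batched(features, n_clusters=25) would then run the partitioned build, provided the whole matrix is addressable from host memory.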

Are the out of memory errors you are getting happening on the GPU or in RAM? You have 512GB of RAM available but the nn-descent partitioning algorithm will still require the data be available in RAM. It's possible you could try to use memory mapping for this, but I do caution that we have not tried this.
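For reference, memory mapping a raw fp32 matrix on the host could look like the sketch below, using numpy.memmap. This only shows the NumPy side; whether cuML accepts a memory-mapped array as input is untested, as noted above. The file name sae_features.f32 and the helper open_feature_matrix are hypothetical, and a tiny temporary file stands in for the real data.

```python
import os
import tempfile

import numpy as np


def open_feature_matrix(path, n_rows, n_cols):
    """Open a raw row-major float32 file as a read-only array backed by disk."""
    return np.memmap(path, dtype=np.float32, mode="r", shape=(n_rows, n_cols))


# Tiny stand-in file so the sketch runs anywhere; the real matrix would be
# written once in the same raw float32 layout.
demo_path = os.path.join(tempfile.mkdtemp(), "sae_features.f32")
np.arange(12, dtype=np.float32).reshape(4, 3).tofile(demo_path)

features = open_feature_matrix(demo_path, n_rows=4, n_cols=3)
print(features.shape, features.dtype)
```

Pages are read from disk only when rows are touched, so the operating system decides how much of the file is resident in RAM at any time.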

That being said, I would maybe try 25-30 partitions so that each partition contains ~13-16GB of data. If that still doesn't work, you can try increasing the number of partitions further. Please let us know if we can help further!
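The partition sizes quoted above follow from dividing the ~400GB matrix evenly. A small sketch of that arithmetic is below; it assumes equal-sized partitions, which k-means does not guarantee, so real partitions may be larger or smaller. The function names are illustrative only.

```python
import math


def partition_size_gb(total_gb, n_clusters):
    """Average data per partition, assuming the partitions are equally sized."""
    if n_clusters < 1:
        raise ValueError("n_clusters must be at least 1")
    return total_gb / n_clusters


def clusters_for_target(total_gb, target_gb):
    """Smallest partition count whose average size is at most target_gb."""
    return math.ceil(total_gb / target_gb)


for n in (25, 30):
    print(f"{n} partitions -> {partition_size_gb(400, n):.1f} GB each")
```

With a 24GB card such as the 3090, the partition itself is not the only thing in GPU memory, so leaving headroom below the card's capacity seems prudent.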

Member

cjnolet commented Dec 12, 2024

@lc82111 In case you haven't seen it yet, I also wanted to bring your attention to a recent blog that explains the scaling algorithm in more detail: https://developer.nvidia.com/blog/even-faster-and-more-scalable-umap-on-the-gpu-with-rapids-cuml/.

Prior to this feature, it was expected that the algorithm would scale by training a UMAP embedding on a smaller subsample of the data and then using transform() to embed the rest of the data.
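The older subsample-then-transform approach can be sketched generically as below. It assumes only that the model exposes fit() and transform() and returns NumPy-compatible output for NumPy input; the function name and the batching scheme are illustrative choices, not something prescribed in this thread.

```python
import numpy as np


def fit_on_subsample_then_transform(model, X, sample_size, batch_size, seed=0):
    """Fit on a random subset of rows, then embed every row in batches."""
    rng = np.random.default_rng(seed)
    n_rows = X.shape[0]
    # Sorted indices give more sequential reads if X is backed by disk.
    idx = np.sort(rng.choice(n_rows, size=min(sample_size, n_rows), replace=False))
    model.fit(X[idx])

    chunks = []
    for start in range(0, n_rows, batch_size):
        # Only one batch needs to be on the device at a time.
        chunks.append(np.asarray(model.transform(X[start:start + batch_size])))
    return np.concatenate(chunks, axis=0)
```

The trade-off is that the embedding is learned from the subsample only, so structure that appears only in the remaining rows cannot influence it.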
