Skip to content

Commit

Permalink
Refactor rmm usage in cudf.pandas (#16021)
Browse files Browse the repository at this point in the history
This PR addresses review comments made by @bdice here: #15628 (review)

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: #16021
  • Loading branch information
galipremsagar authored Jun 13, 2024
1 parent 3cb3df3 commit 3f8f214
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions python/cudf/cudf/pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import os
import warnings

import rmm.mr

from .fast_slow_proxy import is_proxy_object
from .magics import load_ipython_extension
from .profiler import Profiler
Expand All @@ -22,20 +24,16 @@ def install():
loader = ModuleAccelerator.install("pandas", "cudf", "pandas")
global LOADED
LOADED = loader is not None
import os

if (rmm_mode := os.getenv("CUDF_PANDAS_RMM_MODE", None)) is not None:
import rmm.mr
from rmm.mr import available_device_memory

# Check if a non-default memory resource is set
current_mr = rmm.mr.get_current_device_resource()
if not isinstance(current_mr, rmm.mr.CudaMemoryResource):
warnings.warn(
f"cudf.pandas detected an already configured memory resource, ignoring 'CUDF_PANDAS_RMM_MODE'={str(rmm_mode)}",
UserWarning,
)
free_memory, _ = available_device_memory()
free_memory, _ = rmm.mr.available_device_memory()
free_memory = int(round(float(free_memory) * 0.80 / 256) * 256)

if rmm_mode == "cuda":
Expand All @@ -55,13 +53,13 @@ def install():
mr = rmm.mr.ManagedMemoryResource()
rmm.mr.set_current_device_resource(mr)
elif rmm_mode == "managed_pool":
rmm.reinitialize(
managed_memory=True,
pool_allocator=True,
mr = rmm.mr.PoolMemoryResource(
rmm.mr.ManagedMemoryResource(),
initial_pool_size=free_memory,
)
rmm.mr.set_current_device_resource(mr)
else:
raise TypeError(f"Unsupported rmm mode: {rmm_mode}")
raise ValueError(f"Unsupported rmm mode: {rmm_mode}")


def pytest_load_initial_conftests(early_config, parser, args):
Expand Down

0 comments on commit 3f8f214

Please sign in to comment.