Use TrackingResourceAdaptor to get better debug info #1079

Merged
1 change: 1 addition & 0 deletions dask_cuda/benchmarks/common.py
@@ -122,6 +122,7 @@ def run(client: Client, args: Namespace, config: Config):
         args.rmm_pool_size,
         args.disable_rmm_pool,
         args.rmm_log_directory,
+        args.enable_rmm_statistics,
     )
     address_to_index, results, message_data = gather_bench_results(client, args, config)
     p2p_bw = peer_to_peer_bandwidths(message_data, address_to_index)
16 changes: 15 additions & 1 deletion dask_cuda/benchmarks/utils.py
@@ -105,6 +105,11 @@ def parse_benchmark_args(description="Generic dask-cuda Benchmark", args_list=[]
         help="Directory to write worker and scheduler RMM log files to. "
         "Logging is only enabled if RMM memory pool is enabled.",
     )
+    cluster_args.add_argument(
+        "--enable-rmm-statistics",
+        action="store_true",
+        help="Use RMM's StatisticsResourceAdaptor to gather allocation statistics",
+    )
     cluster_args.add_argument(
         "--enable-tcp-over-ucx",
         default=None,
@@ -340,6 +345,7 @@ def setup_memory_pool(
     pool_size=None,
     disable_pool=False,
     log_directory=None,
+    statistics=False,
 ):
     import cupy

@@ -358,16 +364,23 @@
             log_file_name=get_rmm_log_file_name(dask_worker, logging, log_directory),
         )
         cupy.cuda.set_allocator(rmm.rmm_cupy_allocator)
+    if statistics:
+        rmm.mr.set_current_device_resource(
+            rmm.mr.StatisticsResourceAdaptor(rmm.mr.get_current_device_resource())
+        )


-def setup_memory_pools(client, is_gpu, pool_size, disable_pool, log_directory):
+def setup_memory_pools(
+    client, is_gpu, pool_size, disable_pool, log_directory, statistics
+):
     if not is_gpu:
         return
     client.run(
         setup_memory_pool,
         pool_size=pool_size,
         disable_pool=disable_pool,
         log_directory=log_directory,
+        statistics=statistics,
     )
     # Create an RMM pool on the scheduler due to occasional deserialization
     # of CUDA objects. May cause issues with InfiniBand otherwise.
@@ -376,6 +389,7 @@ def setup_memory_pools(client, is_gpu, pool_size, disable_pool, log_directory):
         pool_size=1e9,
         disable_pool=disable_pool,
         log_directory=log_directory,
+        statistics=statistics,
     )
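As context for the new flag, a minimal sketch (requires a CUDA GPU and the rmm package; not part of this PR) of what wrapping the current device resource in StatisticsResourceAdaptor provides:

    import rmm
    import rmm.mr

    # Wrap the active resource, just as setup_memory_pool() above does
    rmm.mr.set_current_device_resource(
        rmm.mr.StatisticsResourceAdaptor(rmm.mr.get_current_device_resource())
    )

    buf = rmm.DeviceBuffer(size=2**20)  # allocate 1 MiB through the adaptor
    mr = rmm.mr.get_current_device_resource()
    print(mr.allocation_counts["current_bytes"])  # expected: 1048576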


2 changes: 1 addition & 1 deletion dask_cuda/local_cuda_cluster.py
@@ -124,7 +124,7 @@ class LocalCUDACluster(LocalCluster):
         Managed memory is currently incompatible with NVLink. Trying to enable both
         will result in an exception.
     rmm_async: bool, default False
-        Initialize each worker withh RMM and set it to use RMM's asynchronous allocator.
+        Initialize each worker with RMM and set it to use RMM's asynchronous allocator.
         See ``rmm.mr.CudaAsyncMemoryResource`` for more info.

         .. warning::
46 changes: 41 additions & 5 deletions dask_cuda/proxify_host_file.py
@@ -47,6 +47,38 @@
 T = TypeVar("T")


+def get_rmm_device_memory_usage() -> Optional[int]:
+    """Get current bytes allocated on current device through RMM
+
+    Check the current RMM resource stack for resources such as
+    StatisticsResourceAdaptor and TrackingResourceAdaptor that
+    can report the current allocated bytes. Returns None, if
+    no such resources exist.
+
+    Returns
+    -------
+    nbytes: int or None
+        Number of bytes allocated on device through RMM or None
+    """
+
+    def get_rmm_memory_resource_stack(mr) -> list:
+        if hasattr(mr, "upstream_mr"):
+            return [mr] + get_rmm_memory_resource_stack(mr.upstream_mr)
+        return [mr]
+
+    try:
+        import rmm
+    except ImportError:
+        return None
+
+    for mr in get_rmm_memory_resource_stack(rmm.mr.get_current_device_resource()):
+        if isinstance(mr, rmm.mr.TrackingResourceAdaptor):
+            return mr.get_allocated_bytes()
+        if isinstance(mr, rmm.mr.StatisticsResourceAdaptor):
+            return mr.allocation_counts["current_bytes"]
+    return None


 class Proxies(abc.ABC):
     """Abstract base class to implement tracking of proxies

Review thread on the docstring (marked as resolved by pentschev):

Member: Suggested change:

-    StatisticsResourceAdaptor and TrackingResourceAdaptor that
-    can report the current allocated bytes. Returns None, if
+    ``StatisticsResourceAdaptor`` and ``TrackingResourceAdaptor`` that
+    can report the current allocated bytes. Returns ``None``, if

@madsbk (Member, Author), Jan 11, 2023: I think we should only use a single ` for functions, variables, etc.
https://numpydoc.readthedocs.io/en/latest/format.html#common-rest-concepts

Member: TBH, I'm not 100% confident I know the difference between a single ` and a double ``. IIRC, the change to use double backticks was started by @charlesbluca when he was working on RTD; could you remind us why that change was made?

In any case, I won't block this PR for this right now; if need be, we can address it in a follow-up PR.

@@ -591,12 +623,16 @@ def oom(nbytes: int) -> bool:
                 traceback.print_stack(file=f)
                 f.seek(0)
                 tb = f.read()
+
+            dev_mem = get_rmm_device_memory_usage()
+            dev_msg = ""
+            if dev_mem is not None:
+                dev_msg = f"RMM allocs: {format_bytes(dev_mem)}, "
+
             self.logger.warning(
-                "RMM allocation of %s failed, spill-on-demand couldn't "
-                "find any device memory to spill:\n%s\ntraceback:\n%s\n",
-                format_bytes(nbytes),
-                self.manager.pprint(),
-                tb,
+                f"RMM allocation of {format_bytes(nbytes)} failed, "
+                "spill-on-demand couldn't find any device memory to "
+                f"spill.\n{dev_msg}{self.manager}, traceback:\n{tb}\n"
             )
             # Since we didn't find anything to spill, we give up.
             return False
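For illustration only, a self-contained sketch of the message shape the new code produces; nbytes, the manager repr, and the traceback below are hypothetical stand-ins:

    from dask.utils import format_bytes

    nbytes = 2**20  # hypothetical failed allocation size
    dev_mem = 2**20  # hypothetical value from get_rmm_device_memory_usage()
    manager = "<ProxyManager dev_limit=... host_limit=...>: Empty"  # stand-in repr
    tb = "<traceback>"  # stand-in for the captured stack

    dev_msg = ""
    if dev_mem is not None:
        dev_msg = f"RMM allocs: {format_bytes(dev_mem)}, "

    print(
        f"RMM allocation of {format_bytes(nbytes)} failed, "
        "spill-on-demand couldn't find any device memory to "
        f"spill.\n{dev_msg}{manager}, traceback:\n{tb}\n"
    )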
29 changes: 18 additions & 11 deletions dask_cuda/tests/test_proxify_host_file.py
@@ -1,4 +1,3 @@
-import re
 from typing import Iterable
 from unittest.mock import patch

@@ -10,6 +9,7 @@
 import dask.dataframe
 from dask.dataframe.shuffle import shuffle_group
 from dask.sizeof import sizeof
+from dask.utils import format_bytes
 from distributed import Client
 from distributed.utils_test import gen_test
 from distributed.worker import get_worker
@@ -448,25 +448,32 @@ def test_on_demand_debug_info():
     if not hasattr(rmm.mr, "FailureCallbackResourceAdaptor"):
         pytest.skip("RMM doesn't implement FailureCallbackResourceAdaptor")

-    total_mem = get_device_total_memory()
+    rmm_pool_size = 2**20

     def task():
-        rmm.DeviceBuffer(size=total_mem + 1)
+        (
+            rmm.DeviceBuffer(size=rmm_pool_size // 2),
+            rmm.DeviceBuffer(size=rmm_pool_size // 2),
+            rmm.DeviceBuffer(size=rmm_pool_size),  # Trigger OOM
+        )

-    with dask_cuda.LocalCUDACluster(n_workers=1, jit_unspill=True) as cluster:
+    with dask_cuda.LocalCUDACluster(
+        n_workers=1,
+        jit_unspill=True,
+        rmm_pool_size=rmm_pool_size,
+        rmm_maximum_pool_size=rmm_pool_size,
+        rmm_track_allocations=True,
+    ) as cluster:
         with Client(cluster) as client:
             # Warmup, which triggers the initialization of spill on demand
             client.submit(range, 10).result()

             # Submit too large RMM buffer
-            with pytest.raises(
-                MemoryError, match=r".*std::bad_alloc:.*CUDA error at:.*"
-            ):
+            with pytest.raises(MemoryError, match="Maximum pool size exceeded"):
                 client.submit(task).result()

             log = str(client.get_worker_logs())
-            assert re.search(
-                "WARNING - RMM allocation of .* failed, spill-on-demand", log
-            )
-            assert re.search("<ProxyManager dev_limit=.* host_limit=.*>: Empty", log)
+            size = format_bytes(rmm_pool_size)
+            assert f"WARNING - RMM allocation of {size} failed" in log
+            assert f"RMM allocs: {size}" in log
+            assert "traceback:" in log