Skip to content

Commit

Permalink
Add support for UCXX (#1268)
Browse files Browse the repository at this point in the history
Add support for UCXX via support for `protocol="ucxx"`. Extend existing UCX-Py tests to test both UCX-Py and UCXX now.

Authors:
  - Peter Andreas Entschev (https://github.com/pentschev)

Approvers:
  - Mads R. B. Kristensen (https://github.com/madsbk)
  - Ray Douglass (https://github.com/raydouglass)

URL: #1268
  • Loading branch information
pentschev authored Oct 31, 2023
1 parent d9e1001 commit 004185e
Show file tree
Hide file tree
Showing 18 changed files with 248 additions and 86 deletions.
2 changes: 1 addition & 1 deletion ci/test_python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ DASK_CUDA_WAIT_WORKERS_MIN_TIMEOUT=20 \
UCXPY_IFNAME=eth0 \
UCX_WARN_UNUSED_ENV_VARS=n \
UCX_MEMTYPE_CACHE=n \
timeout 40m pytest \
timeout 60m pytest \
-vv \
--durations=0 \
--capture=no \
Expand Down
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/local_cudf_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}"
)
print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
if args.protocol == "ucx":
if args.protocol in ["ucx", "ucxx"]:
print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
Expand Down
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/local_cudf_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
)
print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
print_key_value(key="Frac-match", value=f"{args.frac_match}")
if args.protocol == "ucx":
if args.protocol in ["ucx", "ucxx"]:
print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
Expand Down
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/local_cudf_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}"
)
print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
if args.protocol == "ucx":
if args.protocol in ["ucx", "ucxx"]:
print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
Expand Down
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/local_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
)
print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
print_key_value(key="Protocol", value=f"{args.protocol}")
if args.protocol == "ucx":
if args.protocol in ["ucx", "ucxx"]:
print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
Expand Down
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/local_cupy_map_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
)
print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
print_key_value(key="Protocol", value=f"{args.protocol}")
if args.protocol == "ucx":
if args.protocol in ["ucx", "ucxx"]:
print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
Expand Down
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def parse_benchmark_args(description="Generic dask-cuda Benchmark", args_list=[]
cluster_args.add_argument(
"-p",
"--protocol",
choices=["tcp", "ucx"],
choices=["tcp", "ucx", "ucxx"],
default="tcp",
type=str,
help="The communication protocol to use.",
Expand Down
63 changes: 47 additions & 16 deletions dask_cuda/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numba.cuda

import dask
import distributed.comm.ucx
from distributed.diagnostics.nvml import get_device_index_and_uuid, has_cuda_context

from .utils import get_ucx_config
Expand All @@ -23,12 +22,21 @@ def _create_cuda_context_handler():
numba.cuda.current_context()


def _create_cuda_context():
def _create_cuda_context(protocol="ucx"):
if protocol not in ["ucx", "ucxx"]:
return
try:
# Added here to ensure the parent `LocalCUDACluster` process creates the CUDA
# context directly from the UCX module, thus avoiding a similar warning there.
try:
distributed.comm.ucx.init_once()
if protocol == "ucx":
import distributed.comm.ucx

distributed.comm.ucx.init_once()
elif protocol == "ucxx":
import distributed_ucxx.ucxx

distributed_ucxx.ucxx.init_once()
except ModuleNotFoundError:
# UCX initialization has to be delegated to Distributed, it will take care
# of setting correct environment variables and importing `ucp` after that.
Expand All @@ -39,20 +47,35 @@ def _create_cuda_context():
os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
)
ctx = has_cuda_context()
if (
ctx.has_context
and not distributed.comm.ucx.cuda_context_created.has_context
):
distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
if protocol == "ucx":
if (
ctx.has_context
and not distributed.comm.ucx.cuda_context_created.has_context
):
distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
elif protocol == "ucxx":
if (
ctx.has_context
and not distributed_ucxx.ucxx.cuda_context_created.has_context
):
distributed_ucxx.ucxx._warn_existing_cuda_context(ctx, os.getpid())

_create_cuda_context_handler()

if not distributed.comm.ucx.cuda_context_created.has_context:
ctx = has_cuda_context()
if ctx.has_context and ctx.device_info != cuda_visible_device:
distributed.comm.ucx._warn_cuda_context_wrong_device(
cuda_visible_device, ctx.device_info, os.getpid()
)
if protocol == "ucx":
if not distributed.comm.ucx.cuda_context_created.has_context:
ctx = has_cuda_context()
if ctx.has_context and ctx.device_info != cuda_visible_device:
distributed.comm.ucx._warn_cuda_context_wrong_device(
cuda_visible_device, ctx.device_info, os.getpid()
)
elif protocol == "ucxx":
if not distributed_ucxx.ucxx.cuda_context_created.has_context:
ctx = has_cuda_context()
if ctx.has_context and ctx.device_info != cuda_visible_device:
distributed_ucxx.ucxx._warn_cuda_context_wrong_device(
cuda_visible_device, ctx.device_info, os.getpid()
)

except Exception:
logger.error("Unable to start CUDA Context", exc_info=True)
Expand All @@ -64,6 +87,7 @@ def initialize(
enable_infiniband=None,
enable_nvlink=None,
enable_rdmacm=None,
protocol="ucx",
):
"""Create CUDA context and initialize UCX-Py, depending on user parameters.
Expand Down Expand Up @@ -118,7 +142,7 @@ def initialize(
dask.config.set({"distributed.comm.ucx": ucx_config})

if create_cuda_context:
_create_cuda_context()
_create_cuda_context(protocol=protocol)


@click.command()
Expand All @@ -127,6 +151,12 @@ def initialize(
default=False,
help="Create CUDA context",
)
@click.option(
"--protocol",
default=None,
type=str,
help="Communication protocol, such as: 'tcp', 'tls', 'ucx' or 'ucxx'.",
)
@click.option(
"--enable-tcp-over-ucx/--disable-tcp-over-ucx",
default=False,
Expand All @@ -150,10 +180,11 @@ def initialize(
def dask_setup(
service,
create_cuda_context,
protocol,
enable_tcp_over_ucx,
enable_infiniband,
enable_nvlink,
enable_rdmacm,
):
if create_cuda_context:
_create_cuda_context()
_create_cuda_context(protocol=protocol)
9 changes: 6 additions & 3 deletions dask_cuda/local_cuda_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,11 @@ def __init__(
if enable_tcp_over_ucx or enable_infiniband or enable_nvlink:
if protocol is None:
protocol = "ucx"
elif protocol != "ucx":
raise TypeError("Enabling InfiniBand or NVLink requires protocol='ucx'")
elif protocol not in ["ucx", "ucxx"]:
raise TypeError(
"Enabling InfiniBand or NVLink requires protocol='ucx' or "
"protocol='ucxx'"
)

self.host = kwargs.get("host", None)

Expand Down Expand Up @@ -371,7 +374,7 @@ def __init__(
) + ["dask_cuda.initialize"]
self.new_spec["options"]["preload_argv"] = self.new_spec["options"].get(
"preload_argv", []
) + ["--create-cuda-context"]
) + ["--create-cuda-context", "--protocol", protocol]

self.cuda_visible_devices = CUDA_VISIBLE_DEVICES
self.scale(n_workers)
Expand Down
42 changes: 32 additions & 10 deletions dask_cuda/tests/test_dgx.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,13 @@ def test_default():
assert not p.exitcode


def _test_tcp_over_ucx():
ucp = pytest.importorskip("ucp")
def _test_tcp_over_ucx(protocol):
if protocol == "ucx":
ucp = pytest.importorskip("ucp")
elif protocol == "ucxx":
ucp = pytest.importorskip("ucxx")

with LocalCUDACluster(enable_tcp_over_ucx=True) as cluster:
with LocalCUDACluster(protocol=protocol, enable_tcp_over_ucx=True) as cluster:
with Client(cluster) as client:
res = da.from_array(numpy.arange(10000), chunks=(1000,))
res = res.sum().compute()
Expand All @@ -93,10 +96,17 @@ def check_ucx_options():
assert all(client.run(check_ucx_options).values())


def test_tcp_over_ucx():
ucp = pytest.importorskip("ucp") # NOQA: F841
@pytest.mark.parametrize(
"protocol",
["ucx", "ucxx"],
)
def test_tcp_over_ucx(protocol):
if protocol == "ucx":
pytest.importorskip("ucp")
elif protocol == "ucxx":
pytest.importorskip("ucxx")

p = mp.Process(target=_test_tcp_over_ucx)
p = mp.Process(target=_test_tcp_over_ucx, args=(protocol,))
p.start()
p.join()
assert not p.exitcode
Expand All @@ -117,9 +127,14 @@ def test_tcp_only():
assert not p.exitcode


def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm):
def _test_ucx_infiniband_nvlink(
protocol, enable_infiniband, enable_nvlink, enable_rdmacm
):
cupy = pytest.importorskip("cupy")
ucp = pytest.importorskip("ucp")
if protocol == "ucx":
ucp = pytest.importorskip("ucp")
elif protocol == "ucxx":
ucp = pytest.importorskip("ucxx")

if enable_infiniband is None and enable_nvlink is None and enable_rdmacm is None:
enable_tcp_over_ucx = None
Expand All @@ -135,13 +150,15 @@ def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm)
cm_tls_priority = ["tcp"]

initialize(
protocol=protocol,
enable_tcp_over_ucx=enable_tcp_over_ucx,
enable_infiniband=enable_infiniband,
enable_nvlink=enable_nvlink,
enable_rdmacm=enable_rdmacm,
)

with LocalCUDACluster(
protocol=protocol,
interface="ib0",
enable_tcp_over_ucx=enable_tcp_over_ucx,
enable_infiniband=enable_infiniband,
Expand Down Expand Up @@ -171,6 +188,7 @@ def check_ucx_options():
assert all(client.run(check_ucx_options).values())


@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
@pytest.mark.parametrize(
"params",
[
Expand All @@ -185,8 +203,11 @@ def check_ucx_options():
_get_dgx_version() == DGXVersion.DGX_A100,
reason="Automatic InfiniBand device detection Unsupported for %s" % _get_dgx_name(),
)
def test_ucx_infiniband_nvlink(params):
ucp = pytest.importorskip("ucp") # NOQA: F841
def test_ucx_infiniband_nvlink(protocol, params):
if protocol == "ucx":
ucp = pytest.importorskip("ucp")
elif protocol == "ucxx":
ucp = pytest.importorskip("ucxx")

if params["enable_infiniband"]:
if not any([at.startswith("rc") for at in ucp.get_active_transports()]):
Expand All @@ -195,6 +216,7 @@ def test_ucx_infiniband_nvlink(params):
p = mp.Process(
target=_test_ucx_infiniband_nvlink,
args=(
protocol,
params["enable_infiniband"],
params["enable_nvlink"],
params["enable_rdmacm"],
Expand Down
8 changes: 4 additions & 4 deletions dask_cuda/tests/test_explicit_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _test_local_cluster(protocol):
assert sum(c.run(my_rank, 0)) == sum(range(4))


@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
def test_local_cluster(protocol):
p = mp.Process(target=_test_local_cluster, args=(protocol,))
p.start()
Expand Down Expand Up @@ -160,7 +160,7 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):

@pytest.mark.parametrize("nworkers", [1, 2, 3])
@pytest.mark.parametrize("backend", ["pandas", "cudf"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
@pytest.mark.parametrize("_partitions", [True, False])
def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):
if backend == "cudf":
Expand Down Expand Up @@ -256,7 +256,7 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers):

@pytest.mark.parametrize("nworkers", [1, 2, 4])
@pytest.mark.parametrize("backend", ["pandas", "cudf"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
def test_dataframe_shuffle_merge(backend, protocol, nworkers):
if backend == "cudf":
pytest.importorskip("cudf")
Expand Down Expand Up @@ -293,7 +293,7 @@ def _test_jit_unspill(protocol):
assert_eq(got, expected)


@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
def test_jit_unspill(protocol):
pytest.importorskip("cudf")

Expand Down
8 changes: 6 additions & 2 deletions dask_cuda/tests/test_from_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@

from dask_cuda import LocalCUDACluster

pytest.importorskip("ucp")
cupy = pytest.importorskip("cupy")


@pytest.mark.parametrize("protocol", ["ucx", "tcp"])
@pytest.mark.parametrize("protocol", ["ucx", "ucxx", "tcp"])
def test_ucx_from_array(protocol):
if protocol == "ucx":
pytest.importorskip("ucp")
elif protocol == "ucxx":
pytest.importorskip("ucxx")

N = 10_000
with LocalCUDACluster(protocol=protocol) as cluster:
with Client(cluster):
Expand Down
Loading

0 comments on commit 004185e

Please sign in to comment.