Skip to content

Commit

Permalink
Changed code to disable k_truss on CUDA 11.4 to not use numba.cuda.ru…
Browse files Browse the repository at this point in the history
…ntime.get_version() at import time since this creates a CUDA context which breaks dask LocalCUDACluster init (causes a nccl init invaid usage exception).
  • Loading branch information
rlratzel committed Sep 9, 2021
1 parent c3d798f commit e1222e4
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 49 deletions.
4 changes: 2 additions & 2 deletions python/cugraph/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
from cugraph.dask.common.mg_utils import get_visible_devices


# session-wide fixtures
# module-wide fixtures

@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def dask_client():
dask_scheduler_file = os.environ.get("SCHEDULER_FILE")
cluster = None
Expand Down
39 changes: 2 additions & 37 deletions python/cugraph/cugraph/community/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,42 +23,7 @@
)
from cugraph.community.subgraph_extraction import subgraph
from cugraph.community.triangle_count import triangles
from cugraph.community.ktruss_subgraph import ktruss_subgraph
from cugraph.community.ktruss_subgraph import k_truss
from cugraph.community.egonet import ego_graph
from cugraph.community.egonet import batched_ego_graphs

# FIXME: special case for ktruss on CUDA 11.4: an 11.4 bug causes ktruss to
# crash in that environment. Allow ktruss to import on non-11.4 systems, but
# replace ktruss with a __UnsupportedModule instance, which lazily raises an
# exception when referenced.
from numba import cuda
try:
__cuda_version = cuda.runtime.get_version()
except cuda.cudadrv.runtime.CudaRuntimeAPIError:
__cuda_version = "n/a"

__ktruss_unsupported_cuda_version = (11, 4)

class __UnsupportedModule:
def __init__(self, exception):
self.__exception = exception

def __getattr__(self, attr):
raise self.__exception

def __call__(self, *args, **kwargs):
raise self.__exception


if __cuda_version != __ktruss_unsupported_cuda_version:
from cugraph.community.ktruss_subgraph import ktruss_subgraph
from cugraph.community.ktruss_subgraph import k_truss
else:
__kuvs = ".".join([str(n) for n in __ktruss_unsupported_cuda_version])
k_truss = __UnsupportedModule(
NotImplementedError("k_truss is not currently supported in CUDA"
f" {__kuvs} environments.")
)
ktruss_subgraph = __UnsupportedModule(
NotImplementedError("ktruss_subgraph is not currently supported in CUDA"
f" {__kuvs} environments.")
)
27 changes: 17 additions & 10 deletions python/cugraph/cugraph/community/ktruss_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,24 @@
from cugraph.utilities import check_nx_graph
from cugraph.utilities import cugraph_to_nx

from numba import cuda


# FIXME: special case for ktruss on CUDA 11.4: an 11.4 bug causes ktruss to
# crash in that environment. Allow ktruss to import on non-11.4 systems, but
# raise an exception if ktruss is directly imported on 11.4.
from numba import cuda
try:
__cuda_version = cuda.runtime.get_version()
except cuda.cudadrv.runtime.CudaRuntimeAPIError:
__cuda_version = "n/a"
def _ensure_compatible_cuda_version():
try:
cuda_version = cuda.runtime.get_version()
except cuda.cudadrv.runtime.CudaRuntimeAPIError:
cuda_version = "n/a"

__ktruss_unsupported_cuda_version = (11, 4)
unsupported_cuda_version = (11, 4)

if __cuda_version == __ktruss_unsupported_cuda_version:
__kuvs = ".".join([str(n) for n in __ktruss_unsupported_cuda_version])
raise NotImplementedError("k_truss is not currently supported in CUDA"
f" {__kuvs} environments.")
if cuda_version == unsupported_cuda_version:
ver_string = ".".join([str(n) for n in unsupported_cuda_version])
raise NotImplementedError("k_truss is not currently supported in CUDA"
f" {ver_string} environments.")


def k_truss(G, k):
Expand Down Expand Up @@ -62,6 +65,8 @@ def k_truss(G, k):
The networkx graph will NOT have all attributes copied over
"""

_ensure_compatible_cuda_version()

G, isNx = check_nx_graph(G)

if isNx is True:
Expand Down Expand Up @@ -137,6 +142,8 @@ def ktruss_subgraph(G, k, use_weights=True):
>>> k_subgraph = cugraph.ktruss_subgraph(G, 3)
"""

_ensure_compatible_cuda_version()

KTrussSubgraph = Graph()
if type(G) is not Graph:
raise Exception("input graph must be undirected")
Expand Down

0 comments on commit e1222e4

Please sign in to comment.