diff --git a/benchmarks/cugraph-dgl/pytest-based/bench_cugraph_dgl_uniform_neighbor_sample.py b/benchmarks/cugraph-dgl/pytest-based/bench_cugraph_dgl_uniform_neighbor_sample.py
index f05c4364840..eeee163b0af 100644
--- a/benchmarks/cugraph-dgl/pytest-based/bench_cugraph_dgl_uniform_neighbor_sample.py
+++ b/benchmarks/cugraph-dgl/pytest-based/bench_cugraph_dgl_uniform_neighbor_sample.py
@@ -39,7 +39,7 @@ def create_graph(graph_data):
     """
     Create a graph instance based on the data to be loaded/generated.
-    """ 
+    """
     print("Initalize Pool on client")
     rmm.reinitialize(pool_allocator=True)
     # Assume strings are names of datasets in the datasets package
@@ -77,7 +77,7 @@ def create_graph(graph_data):
     num_nodes_dict = {'_N':num_nodes}
     gs = CuGraphStorage(num_nodes_dict=num_nodes_dict, single_gpu=True)
-    gs.add_edge_data(edgelist_df, 
+    gs.add_edge_data(edgelist_df,
                      # reverse to make same graph as cugraph
                      node_col_names=['dst', 'src'],
                      canonical_etype=['_N', 'connects', '_N'])
@@ -90,11 +90,9 @@ def create_mg_graph(graph_data):
     """
     Create a graph instance based on the data to be loaded/generated.
     """
-    ## Reserving GPU 0 for client(trainer/service project)
-    n_devices = os.getenv('DASK_NUM_WORKERS', 4)
-    n_devices = int(n_devices)
+    # range starts at 1 to let 0 be used by benchmark/client process
+    visible_devices = os.getenv("DASK_WORKER_DEVICES", "1,2,3,4")
-    visible_devices = ','.join([str(i) for i in range(1, n_devices+1)])
     cluster = LocalCUDACluster(protocol='ucx', rmm_pool_size='25GB', CUDA_VISIBLE_DEVICES=visible_devices)
     client = Client(cluster)
     Comms.initialize(p2p=True)
@@ -137,7 +135,7 @@ def create_mg_graph(graph_data):
     num_nodes_dict = {'_N':num_nodes}
     gs = CuGraphStorage(num_nodes_dict=num_nodes_dict, single_gpu=False)
-    gs.add_edge_data(edgelist_df, 
+    gs.add_edge_data(edgelist_df,
                      node_col_names=['dst', 'src'],
                      canonical_etype=['_N', 'C', '_N'])
     return (gs, client, cluster)
@@ -166,7 +164,7 @@ def get_uniform_neighbor_sample_args(
         num_start_verts = int(num_verts * 0.25)
     else:
         num_start_verts = batch_size
-    
+
     srcs = G.graphstore.gdata.get_edge_data()['_SRC_']
     start_list = srcs.head(num_start_verts)
     assert len(start_list) == num_start_verts
@@ -229,7 +227,7 @@ def bench_cugraph_dgl_uniform_neighbor_sample(
         fanout_val.reverse()
     sampler = dgl.dataloading.NeighborSampler(uns_args["fanout"])
     sampler_f = sampler.sample_blocks
-    
+
     # Warmup
     _ = sampler_f(g=G, seed_nodes=uns_args["seed_nodes"])
     # print(f"\n{uns_args}")
diff --git a/benchmarks/cugraph/pytest-based/bench_algos.py b/benchmarks/cugraph/pytest-based/bench_algos.py
index a1fefaf237c..d7fcb7812e4 100644
--- a/benchmarks/cugraph/pytest-based/bench_algos.py
+++ b/benchmarks/cugraph/pytest-based/bench_algos.py
@@ -12,7 +12,7 @@
 # limitations under the License.
 import pytest
-
+import numpy as np
 import pytest_benchmark
 # FIXME: Remove this when rapids_pytest_benchmark.gpubenchmark is available
 # everywhere
@@ -29,12 +29,16 @@ def setFixtureParamNames(*args, **kwargs):
         pass
 
+import rmm
+import dask_cudf
+from pylibcugraph.testing import gen_fixture_params_product
+
 import cugraph
+import cugraph.dask as dask_cugraph
 from cugraph.structure.number_map import NumberMap
-from cugraph.testing import utils
-from pylibcugraph.testing import gen_fixture_params_product
+from cugraph.generators import rmat
+from cugraph.testing import utils, mg_utils
 from cugraph.utilities.utils import is_device_version_less_than
-import rmm
 
 from cugraph_benchmarking.params import (
     directed_datasets,
@@ -43,53 +47,122 @@ def setFixtureParamNames(*args, **kwargs):
     pool_allocator,
 )
 
-fixture_params = gen_fixture_params_product(
-    (directed_datasets + undirected_datasets, "ds"),
+# duck-type compatible Dataset for RMAT data
+class RmatDataset:
+    def __init__(self, scale=4, edgefactor=2, mg=False):
+        self._scale = scale
+        self._edgefactor = edgefactor
+        self._edgelist = None
+
+        self.mg = mg
+
+    def __str__(self):
+        mg_str = "mg" if self.mg else "sg"
+        return f"rmat_{mg_str}_{self._scale}_{self._edgefactor}"
+
+    def get_edgelist(self, fetch=False):
+        seed = 42
+        if self._edgelist is None:
+            self._edgelist = rmat(
+                self._scale,
+                (2**self._scale)*self._edgefactor,
+                0.57,  # from Graph500
+                0.19,  # from Graph500
+                0.19,  # from Graph500
+                seed or 42,
+                clip_and_flip=False,
+                scramble_vertex_ids=True,
+                create_using=None,  # return edgelist instead of Graph instance
+                mg=self.mg
+            )
+            rng = np.random.default_rng(seed)
+            if self.mg:
+                self._edgelist["weight"] = self._edgelist.map_partitions(
+                    lambda df: rng.random(size=len(df)))
+            else:
+                self._edgelist["weight"] = rng.random(size=len(self._edgelist))
+
+        return self._edgelist
+
+    def get_graph(self,
+                  fetch=False,
+                  create_using=cugraph.Graph,
+                  ignore_weights=False,
+                  store_transposed=False):
+        if isinstance(create_using, cugraph.Graph):
+            # what about BFS if transposed is True
+            attrs = {"directed": create_using.is_directed()}
+            G = type(create_using)(**attrs)
+        elif type(create_using) is type:
+            G = create_using()
+
+        edge_attr = None if ignore_weights else "weight"
+        df = self.get_edgelist()
+        if isinstance(df, dask_cudf.DataFrame):
+            G.from_dask_cudf_edgelist(df,
+                                      source="src",
+                                      destination="dst",
+                                      edge_attr=edge_attr,
+                                      store_transposed=store_transposed)
+        else:
+            G.from_cudf_edgelist(df,
+                                 source="src",
+                                 destination="dst",
+                                 edge_attr=edge_attr,
+                                 store_transposed=store_transposed)
+        return G
+
+    def get_path(self):
+        """
+        (this is likely not needed for use with pytest-benchmark, just added for
+        API completeness with Dataset.)
+        """
+        return str(self)
+
+    def unload(self):
+        self._edgelist = None
+
+
+_rmat_scale = getattr(pytest, "_rmat_scale", 20)  # ~1M vertices
+_rmat_edgefactor = getattr(pytest, "_rmat_edgefactor", 16)  # ~17M edges
+rmat_sg_dataset = pytest.param(RmatDataset(scale=_rmat_scale,
+                                           edgefactor=_rmat_edgefactor,
+                                           mg=False),
+                               marks=[pytest.mark.rmat_data,
+                                      pytest.mark.sg,
+                                      ])
+rmat_mg_dataset = pytest.param(RmatDataset(scale=_rmat_scale,
+                                           edgefactor=_rmat_edgefactor,
+                                           mg=True),
+                               marks=[pytest.mark.rmat_data,
+                                      pytest.mark.mg,
+                                      ])
+
+rmm_fixture_params = gen_fixture_params_product(
     (managed_memory, "mm"),
     (pool_allocator, "pa"))
-
-###############################################################################
-# Helpers
-def createGraph(csvFileName, graphType=None):
-    """
-    Helper function to create a Graph (directed or undirected) based on
-    csvFileName.
-    """
-    if graphType is None:
-        # There's potential value in verifying that a directed graph can be
-        # created from a undirected dataset, and an undirected from a directed
-        # dataset. (For now?) do not include those combinations to keep
-        # benchmark runtime and complexity lower, and assume tests have
-        # coverage to verify correctness for those combinations.
-        if "directed" in csvFileName.parts:
-            graphType = cugraph.Graph(directed=True)
-        else:
-            graphType = cugraph.Graph()
-
-    gdf = utils.read_csv_file(csvFileName)
-    if len(gdf.columns) == 2:
-        edge_attr = None
-    else:
-        edge_attr = "2"
-
-    return cugraph.from_cudf_edgelist(
-        gdf,
-        source="0", destination="1", edge_attr=edge_attr,
-        create_using=graphType,
-        renumber=True)
-
+dataset_fixture_params = gen_fixture_params_product(
+    (directed_datasets +
+     undirected_datasets +
+     [rmat_sg_dataset, rmat_mg_dataset], "ds"))
 
 # Record the current RMM settings so reinitialize() will be called only when a
 # change is needed (RMM defaults both values to False). The --allow-rmm-reinit
 # option is required to allow the RMM options to be set by the pytest user
 # directly, in order to prevent reinitialize() from being called more than once
 # (see conftest.py for details).
+# The defaults for managed_mem (False) and pool_alloc (True) are set in
+# conftest.py
 RMM_SETTINGS = {"managed_mem": False,
                 "pool_alloc": False}
 
-
+# FIXME: this only changes the RMM config in a SG environment. The dask config
+# that applies to RMM in an MG environment is not changed by this!
 def reinitRMM(managed_mem, pool_alloc):
-
+    """
+    Reinitializes RMM to the value of managed_mem and pool_alloc, but only if
+    those values are different from the current configuration.
+    """
     if (managed_mem != RMM_SETTINGS["managed_mem"]) or \
        (pool_alloc != RMM_SETTINGS["pool_alloc"]):
 
@@ -111,79 +184,86 @@ def reinitRMM(managed_mem, pool_alloc):
 #
 # For benchmarks, the operations performed in fixtures are not measured as part
 # of the benchmark.
+
 @pytest.fixture(scope="module",
-                params=fixture_params)
-def edgelistCreated(request):
-    """
-    Returns a new edgelist created from a CSV, which is specified as part of
-    the parameterization for this fixture.
-    """
+                params=rmm_fixture_params)
+def rmm_config(request):
     # Since parameterized fixtures do not assign param names to param values,
     # manually call the helper to do so. Ensure the order of the name list
     # passed to it matches if there are >1 params.
     # If the request only contains n params, only the first n names are set.
-    setFixtureParamNames(request, ["dataset", "managed_mem", "pool_allocator"])
-
-    csvFileName = request.param[0]
-    reinitRMM(request.param[1], request.param[2])
-    return utils.read_csv_file(csvFileName)
+    setFixtureParamNames(request, ["managed_mem", "pool_allocator"])
+    reinitRMM(request.param[0], request.param[1])
 
 
 @pytest.fixture(scope="module",
-                params=fixture_params)
-def graphWithAdjListComputed(request):
+                params=dataset_fixture_params)
+def dataset(request, rmm_config):
+
     """
-    Create a Graph obj from the CSV file in param, compute the adjacency list
-    and return it.
+    Fixture which provides a Dataset instance, setting up a Dask cluster and
+    client if necessary for MG, to tests and other fixtures. When all
+    tests/fixtures are done with the Dataset, the Dask cluster and client are
+    torn down (if MG) and all loaded data is freed.
     """
-    setFixtureParamNames(request, ["dataset", "managed_mem", "pool_allocator"])
-    csvFileName = request.param[0]
-    reinitRMM(request.param[1], request.param[2])
+    setFixtureParamNames(request, ["dataset"])
+    dataset = request.param[0]
+    client = cluster = None
+    # For now, only RmatDataset instances support MG and have a "mg" attr.
+    if hasattr(dataset, "mg") and dataset.mg:
+        (client, cluster) = mg_utils.start_dask_client()
 
-    G = createGraph(csvFileName, cugraph.structure.graph_classes.Graph)
-    G.view_adj_list()
+    yield dataset
+
+    dataset.unload()
+    if client is not None:
+        mg_utils.stop_dask_client(client, cluster)
+
+
+@pytest.fixture(scope="module")
+def edgelist(request, dataset):
+    df = dataset.get_edgelist()
+    return df
+
+
+@pytest.fixture(scope="module")
+def graph(request, dataset):
+    G = dataset.get_graph()
     return G
 
 
-@pytest.fixture(scope="module",
-                params=fixture_params)
-def anyGraphWithAdjListComputed(request):
-    """
-    Create a Graph (directed or undirected) obj based on the param, compute the
-    adjacency list and return it.
-    """
-    setFixtureParamNames(request, ["dataset", "managed_mem", "pool_allocator"])
-    csvFileName = request.param[0]
-    reinitRMM(request.param[1], request.param[2])
+@pytest.fixture(scope="module")
+def unweighted_graph(request, dataset):
+    G = dataset.get_graph(ignore_weights=True)
+    return G
 
-    G = createGraph(csvFileName)
-    G.view_adj_list()
+
+@pytest.fixture(scope="module")
+def directed_graph(request, dataset):
+    G = dataset.get_graph(create_using=cugraph.Graph(directed=True))
     return G
 
 
-@pytest.fixture(scope="module",
-                params=fixture_params)
-def anyGraphWithTransposedAdjListComputed(request):
+@pytest.fixture(scope="module")
+def transposed_graph(request, dataset):
+    G = dataset.get_graph(store_transposed=True)
+    return G
+
+
+###############################################################################
+def is_graph_distributed(graph):
     """
-    Create a Graph (directed or undirected) obj based on the param, compute the
-    transposed adjacency list and return it.
+ Return True if graph is distributed (for use with cugraph.dask APIs) """ - setFixtureParamNames(request, ["dataset", "managed_mem", "pool_allocator"]) - csvFileName = request.param[0] - reinitRMM(request.param[1], request.param[2]) - - G = createGraph(csvFileName) - G.view_transposed_adj_list() - return G + return isinstance(graph.edgelist.edgelist_df, dask_cudf.DataFrame) ############################################################################### # Benchmarks -@pytest.mark.ETL -def bench_create_graph(gpubenchmark, edgelistCreated): +def bench_create_graph(gpubenchmark, edgelist): gpubenchmark(cugraph.from_cudf_edgelist, - edgelistCreated, - source="0", destination="1", + edgelist, + source="src", destination="dst", create_using=cugraph.structure.graph_classes.Graph, renumber=False) @@ -191,99 +271,142 @@ def bench_create_graph(gpubenchmark, edgelistCreated): # Creating directed Graphs on small datasets runs in micro-seconds, which # results in thousands of rounds before the default threshold is met, so lower # the max_time for this benchmark. -@pytest.mark.ETL @pytest.mark.benchmark( warmup=True, warmup_iterations=10, max_time=0.005 ) -def bench_create_digraph(gpubenchmark, edgelistCreated): +def bench_create_digraph(gpubenchmark, edgelist): gpubenchmark(cugraph.from_cudf_edgelist, - edgelistCreated, - source="0", destination="1", + edgelist, + source="src", destination="dst", create_using=cugraph.Graph(directed=True), renumber=False) -@pytest.mark.ETL -def bench_renumber(gpubenchmark, edgelistCreated): - gpubenchmark(NumberMap.renumber, edgelistCreated, "0", "1") +def bench_renumber(gpubenchmark, edgelist): + gpubenchmark(NumberMap.renumber, edgelist, "src", "dst") -def bench_pagerank(gpubenchmark, anyGraphWithTransposedAdjListComputed): - gpubenchmark(cugraph.pagerank, anyGraphWithTransposedAdjListComputed) +def bench_pagerank(gpubenchmark, transposed_graph): + pagerank = dask_cugraph.pagerank if is_graph_distributed(transposed_graph) \ + else cugraph.pagerank + gpubenchmark(pagerank, transposed_graph) -def bench_bfs(gpubenchmark, anyGraphWithAdjListComputed): - start = anyGraphWithAdjListComputed.edgelist.edgelist_df["src"][0] - gpubenchmark(cugraph.bfs, anyGraphWithAdjListComputed, start) +def bench_bfs(gpubenchmark, graph): + bfs = dask_cugraph.bfs if is_graph_distributed(graph) else cugraph.bfs + start = graph.edgelist.edgelist_df["src"][0] + gpubenchmark(bfs, graph, start) -def bench_force_atlas2(gpubenchmark, anyGraphWithAdjListComputed): - gpubenchmark(cugraph.force_atlas2, anyGraphWithAdjListComputed, - max_iter=50) +def bench_force_atlas2(gpubenchmark, graph): + if is_graph_distributed(graph): + pytest.skip("distributed graphs are not supported") + gpubenchmark(cugraph.force_atlas2, graph, max_iter=50) -def bench_sssp(gpubenchmark, anyGraphWithAdjListComputed): - start = anyGraphWithAdjListComputed.edgelist.edgelist_df["src"][0] - gpubenchmark(cugraph.sssp, anyGraphWithAdjListComputed, start) +def bench_sssp(gpubenchmark, graph): + sssp = dask_cugraph.sssp if is_graph_distributed(graph) else cugraph.sssp + start = graph.edgelist.edgelist_df["src"][0] + gpubenchmark(sssp, graph, start) -def bench_jaccard(gpubenchmark, graphWithAdjListComputed): - gpubenchmark(cugraph.jaccard, graphWithAdjListComputed) +def bench_jaccard(gpubenchmark, unweighted_graph): + G = unweighted_graph + jaccard = dask_cugraph.jaccard if is_graph_distributed(G) else cugraph.jaccard + gpubenchmark(jaccard, G) @pytest.mark.skipif( is_device_version_less_than((7, 0)), reason="Not supported on Pascal") 
-def bench_louvain(gpubenchmark, graphWithAdjListComputed): - gpubenchmark(cugraph.louvain, graphWithAdjListComputed) +def bench_louvain(gpubenchmark, graph): + louvain = dask_cugraph.louvain if is_graph_distributed(graph) else cugraph.louvain + gpubenchmark(louvain, graph) -def bench_weakly_connected_components(gpubenchmark, - anyGraphWithAdjListComputed): - if anyGraphWithAdjListComputed.is_directed(): - G = anyGraphWithAdjListComputed.to_undirected() +def bench_weakly_connected_components(gpubenchmark, graph): + if is_graph_distributed(graph): + pytest.skip("distributed graphs are not supported") + if graph.is_directed(): + G = graph.to_undirected() else: - G = anyGraphWithAdjListComputed + G = graph gpubenchmark(cugraph.weakly_connected_components, G) -def bench_overlap(gpubenchmark, anyGraphWithAdjListComputed): - gpubenchmark(cugraph.overlap, anyGraphWithAdjListComputed) +def bench_overlap(gpubenchmark, unweighted_graph): + G = unweighted_graph + overlap = dask_cugraph.overlap if is_graph_distributed(G) else cugraph.overlap + gpubenchmark(overlap, G) -def bench_triangle_count(gpubenchmark, graphWithAdjListComputed): - gpubenchmark(cugraph.triangle_count, graphWithAdjListComputed) +def bench_triangle_count(gpubenchmark, graph): + tc = dask_cugraph.triangle_count if is_graph_distributed(graph) \ + else cugraph.triangle_count + gpubenchmark(tc, graph) -def bench_spectralBalancedCutClustering(gpubenchmark, - graphWithAdjListComputed): - gpubenchmark(cugraph.spectralBalancedCutClustering, - graphWithAdjListComputed, 2) +def bench_spectralBalancedCutClustering(gpubenchmark, graph): + if is_graph_distributed(graph): + pytest.skip("distributed graphs are not supported") + gpubenchmark(cugraph.spectralBalancedCutClustering, graph, 2) @pytest.mark.skip(reason="Need to guarantee graph has weights, " "not doing that yet") -def bench_spectralModularityMaximizationClustering( - gpubenchmark, anyGraphWithAdjListComputed): - gpubenchmark(cugraph.spectralModularityMaximizationClustering, - anyGraphWithAdjListComputed, 2) +def bench_spectralModularityMaximizationClustering(gpubenchmark, graph): + smmc = dask_cugraph.spectralModularityMaximizationClustering \ + if is_graph_distributed(graph) \ + else cugraph.spectralModularityMaximizationClustering + gpubenchmark(smmc, graph, 2) + + +def bench_graph_degree(gpubenchmark, graph): + gpubenchmark(graph.degree) + + +def bench_graph_degrees(gpubenchmark, graph): + if is_graph_distributed(graph): + pytest.skip("distributed graphs are not supported") + gpubenchmark(graph.degrees) + + +def bench_betweenness_centrality(gpubenchmark, graph): + bc = dask_cugraph.betweenness_centrality if is_graph_distributed(graph) \ + else cugraph.betweenness_centrality + gpubenchmark(bc, graph, k=10, random_state=123) + +def bench_edge_betweenness_centrality(gpubenchmark, graph): + if is_graph_distributed(graph): + pytest.skip("distributed graphs are not supported") + gpubenchmark(cugraph.edge_betweenness_centrality, graph, k=10, seed=123) -def bench_graph_degree(gpubenchmark, anyGraphWithAdjListComputed): - gpubenchmark(anyGraphWithAdjListComputed.degree) +def bench_uniform_neighbor_sample(gpubenchmark, graph): + uns = dask_cugraph.uniform_neighbor_sample if is_graph_distributed(graph) \ + else cugraph.uniform_neighbor_sample -def bench_graph_degrees(gpubenchmark, anyGraphWithAdjListComputed): - gpubenchmark(anyGraphWithAdjListComputed.degrees) + seed = 42 + # FIXME: may need to provide number_of_vertices separately + num_verts_in_graph = graph.number_of_vertices() + 
len_start_list = max(int(num_verts_in_graph * 0.01), 2) + srcs = graph.edgelist.edgelist_df["src"] + frac = len_start_list / num_verts_in_graph + start_list = srcs.sample(frac=frac, random_state=seed) + # Attempt to automatically handle a dask Series + if hasattr(start_list, "compute"): + start_list = start_list.compute() -def bench_betweenness_centrality(gpubenchmark, anyGraphWithAdjListComputed): - gpubenchmark(cugraph.betweenness_centrality, - anyGraphWithAdjListComputed, k=10, random_state=123) + fanout_vals = [5, 5, 5] + gpubenchmark(uns, graph, start_list=start_list, fanout_vals=fanout_vals) -def bench_edge_betweenness_centrality(gpubenchmark, - anyGraphWithAdjListComputed): - gpubenchmark(cugraph.edge_betweenness_centrality, - anyGraphWithAdjListComputed, k=10, seed=123) +def bench_egonet(gpubenchmark, graph): + egonet = dask_cugraph.ego_graph if is_graph_distributed(graph) \ + else cugraph.ego_graph + n = 1 + radius = 2 + gpubenchmark(egonet, graph, n, radius=radius) diff --git a/benchmarks/cugraph/pytest-based/bench_cugraph_uniform_neighbor_sample.py b/benchmarks/cugraph/pytest-based/bench_cugraph_uniform_neighbor_sample.py index 8fe6e81ccf1..157c64b0b20 100644 --- a/benchmarks/cugraph/pytest-based/bench_cugraph_uniform_neighbor_sample.py +++ b/benchmarks/cugraph/pytest-based/bench_cugraph_uniform_neighbor_sample.py @@ -107,10 +107,8 @@ def create_mg_graph(graph_data): Create a graph instance based on the data to be loaded/generated, return a tuple containing (graph_obj, num_verts, client, cluster) """ - n_devices = os.getenv("DASK_NUM_WORKERS", 4) - n_devices = int(n_devices) # range starts at 1 to let let 0 be used by benchmark/client process - visible_devices = ",".join([str(i) for i in range(1, n_devices+1)]) + visible_devices = os.getenv("DASK_WORKER_DEVICES", "1,2,3,4") (client, cluster) = start_dask_client( # enable_tcp_over_ucx=True, diff --git a/benchmarks/cugraph/pytest-based/conftest.py b/benchmarks/cugraph/pytest-based/conftest.py index 435941098de..fd029471869 100644 --- a/benchmarks/cugraph/pytest-based/conftest.py +++ b/benchmarks/cugraph/pytest-based/conftest.py @@ -11,6 +11,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + + def pytest_addoption(parser): parser.addoption("--allow-rmm-reinit", action="store_true", @@ -21,6 +24,22 @@ def pytest_addoption(parser): "since it does not represent a typical use case, and support for " "this may be limited. Instead, consider multiple pytest runs that " "use a fixed set of RMM settings.") + parser.addoption("--rmat-scale", + action="store", + type=int, + default=20, + metavar="scale", + help="For use when using synthetic graph data generated using RMAT. " + "This results in a graph with 2^scale vertices. Default is " + "%(default)s.") + parser.addoption("--rmat-edgefactor", + action="store", + type=int, + default=16, + metavar="edgefactor", + help="For use when using synthetic graph data generated using RMAT. " + "This results in a graph with (2^scale)*edgefactor edges. Default " + "is %(default)s.") def pytest_sessionstart(session): @@ -28,6 +47,13 @@ def pytest_sessionstart(session): # "mark expression" (-m) the markers for no managedmem and # poolallocator. This will result in the RMM reinit() function to be called # only once in the running process (the typical use case). + # + # FIXME: consider making the RMM config options set using a CLI option + # instead of by markers. 
This would mean only one RMM config can be used + # per test session, which could eliminate problems related to calling RMM + # reinit multiple times in the same process. This would not be a major + # change to the benchmark UX since the user is discouraged from doing a + # reinit multiple times anyway (hence the --allow-rmm-reinit flag). if session.config.getoption("allow_rmm_reinit") is False: currentMarkexpr = session.config.getoption("markexpr") @@ -41,3 +67,10 @@ def pytest_sessionstart(session): newMarkexpr = f"({currentMarkexpr}) and ({newMarkexpr})" session.config.option.markexpr = newMarkexpr + + # Set the value of the CLI options for RMAT here since any RmatDataset + # objects must be instantiated prior to running test fixtures in order to + # have their test ID generated properly. + # FIXME: is there a better way to do this? + pytest._rmat_scale = session.config.getoption("rmat_scale") + pytest._rmat_edgefactor = session.config.getoption("rmat_edgefactor") diff --git a/benchmarks/pytest.ini b/benchmarks/pytest.ini index b61fa92d403..6af3aab27fe 100644 --- a/benchmarks/pytest.ini +++ b/benchmarks/pytest.ini @@ -14,7 +14,6 @@ markers = managedmem_off: RMM managed memory disabled poolallocator_on: RMM pool allocator enabled poolallocator_off: RMM pool allocator disabled - ETL: benchmarks for ETL steps small: small datasets tiny: tiny datasets directed: directed datasets @@ -50,6 +49,8 @@ markers = num_clients_32: start 32 cugraph-service clients fanout_10_25: fanout [10, 25] for sampling algos fanout_5_10_15: fanout [5, 10, 15] for sampling algos + rmat_data: RMAT-generated synthetic datasets + file_data: datasets from $RAPIDS_DATASET_ROOT_DIR python_classes = Bench* diff --git a/benchmarks/shared/python/cugraph_benchmarking/params.py b/benchmarks/shared/python/cugraph_benchmarking/params.py index 4cf749d0c21..ee63b8768a6 100644 --- a/benchmarks/shared/python/cugraph_benchmarking/params.py +++ b/benchmarks/shared/python/cugraph_benchmarking/params.py @@ -11,32 +11,68 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pathlib import Path - import pytest -from cugraph.testing import utils from pylibcugraph.testing.utils import gen_fixture_params +from cugraph.testing import RAPIDS_DATASET_ROOT_DIR_PATH +from cugraph.experimental.datasets import ( + Dataset, + karate, +) +# Create Dataset objects from .csv files. +# Once the cugraph.dataset package is updated to include the metadata files for +# these (like karate), these will no longer need to be explicitly instantiated. +hollywood = Dataset( + csv_file=RAPIDS_DATASET_ROOT_DIR_PATH / "csv/undirected/hollywood.csv", + csv_col_names=["src", "dst"], + csv_col_dtypes=["int32", "int32"]) +europe_osm = Dataset( + csv_file=RAPIDS_DATASET_ROOT_DIR_PATH / "csv/undirected/europe_osm.csv", + csv_col_names=["src", "dst"], + csv_col_dtypes=["int32", "int32"]) +cit_patents = Dataset( + csv_file=RAPIDS_DATASET_ROOT_DIR_PATH / "csv/directed/cit-Patents.csv", + csv_col_names=["src", "dst"], + csv_col_dtypes=["int32", "int32"]) +soc_livejournal = Dataset( + csv_file=RAPIDS_DATASET_ROOT_DIR_PATH / "csv/directed/soc-LiveJournal1.csv", + csv_col_names=["src", "dst"], + csv_col_dtypes=["int32", "int32"]) -# FIXME: omitting soc-twitter-2010.csv due to OOM error on some workstations. +# Assume all "file_data" (.csv file on disk) datasets are too small to be useful for MG. 
undirected_datasets = [ - pytest.param(Path(utils.RAPIDS_DATASET_ROOT_DIR) / "karate.csv", - marks=[pytest.mark.tiny, pytest.mark.undirected]), - pytest.param(Path(utils.RAPIDS_DATASET_ROOT_DIR) / "csv/undirected/hollywood.csv", - marks=[pytest.mark.small, pytest.mark.undirected]), - pytest.param(Path(utils.RAPIDS_DATASET_ROOT_DIR) / "csv/undirected/europe_osm.csv", - marks=[pytest.mark.undirected]), - # pytest.param("../datasets/csv/undirected/soc-twitter-2010.csv", - # marks=[pytest.mark.undirected]), + pytest.param(karate, + marks=[pytest.mark.tiny, + pytest.mark.undirected, + pytest.mark.file_data, + pytest.mark.sg, + ]), + pytest.param(hollywood, + marks=[pytest.mark.small, + pytest.mark.undirected, + pytest.mark.file_data, + pytest.mark.sg, + ]), + pytest.param(europe_osm, + marks=[pytest.mark.undirected, + pytest.mark.file_data, + pytest.mark.sg, + ]), ] directed_datasets = [ - pytest.param(Path(utils.RAPIDS_DATASET_ROOT_DIR) / "csv/directed/cit-Patents.csv", - marks=[pytest.mark.small, pytest.mark.directed]), - pytest.param(Path( - utils.RAPIDS_DATASET_ROOT_DIR) / "csv/directed/soc-LiveJournal1.csv", - marks=[pytest.mark.directed]), + pytest.param(cit_patents, + marks=[pytest.mark.small, + pytest.mark.directed, + pytest.mark.file_data, + pytest.mark.sg, + ]), + pytest.param(soc_livejournal, + marks=[pytest.mark.directed, + pytest.mark.file_data, + pytest.mark.sg, + ]), ] managed_memory = [ diff --git a/ci/test_cpp.sh b/ci/test_cpp.sh index a6c4cdb4a4f..f02ac748f18 100755 --- a/ci/test_cpp.sh +++ b/ci/test_cpp.sh @@ -34,7 +34,7 @@ nvidia-smi # RAPIDS_DATASET_ROOT_DIR is used by test scripts export RAPIDS_DATASET_ROOT_DIR="$(realpath datasets)" pushd "${RAPIDS_DATASET_ROOT_DIR}" -./get_test_data.sh +./get_test_data.sh --subset popd EXITCODE=0 diff --git a/ci/test_python.sh b/ci/test_python.sh index 2ee340f6f5c..191e08e57f9 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -43,7 +43,7 @@ nvidia-smi # RAPIDS_DATASET_ROOT_DIR is used by test scripts export RAPIDS_DATASET_ROOT_DIR="$(realpath datasets)" pushd "${RAPIDS_DATASET_ROOT_DIR}" -./get_test_data.sh +./get_test_data.sh --benchmark popd EXITCODE=0 diff --git a/python/cugraph-dgl/tests/conftest.py b/python/cugraph-dgl/tests/conftest.py index c1f4841a905..dc6b7db9b45 100644 --- a/python/cugraph-dgl/tests/conftest.py +++ b/python/cugraph-dgl/tests/conftest.py @@ -11,30 +11,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import pytest -from dask.distributed import Client -from cugraph.dask.comms import comms as Comms -from cugraph.testing.mg_utils import stop_dask_client, start_dask_client +from cugraph.testing.mg_utils import ( + start_dask_client, + stop_dask_client, +) @pytest.fixture(scope="module") def dask_client(): - dask_scheduler_file = os.environ.get("SCHEDULER_FILE") - - if dask_scheduler_file is not None: - dask_client = Client(scheduler_file=dask_scheduler_file) - dask_cluster = None - else: - dask_client, dask_cluster = start_dask_client( - dask_worker_devices="0", protocol="tcp" - ) - - if not Comms.is_initialized(): - Comms.initialize(p2p=True) + # start_dask_client will check for the SCHEDULER_FILE and + # DASK_WORKER_DEVICES env vars and use them when creating a client if + # set. start_dask_client will also initialize the Comms singleton. 
+ dask_client, dask_cluster = start_dask_client( + dask_worker_devices="0", protocol="tcp" + ) yield dask_client stop_dask_client(dask_client, dask_cluster) - print("\ndask_client fixture: client.close() called") diff --git a/python/cugraph/cugraph/dask/common/mg_utils.py b/python/cugraph/cugraph/dask/common/mg_utils.py index 5ab884a5b34..6acda48c9da 100644 --- a/python/cugraph/cugraph/dask/common/mg_utils.py +++ b/python/cugraph/cugraph/dask/common/mg_utils.py @@ -13,11 +13,8 @@ import os -import rmm import numba.cuda -from dask_cuda import LocalCUDACluster -from dask.distributed import Client # FIXME: this raft import breaks the library if ucx-py is # not available. They are necessary only when doing MG work. @@ -32,11 +29,6 @@ default_client = MissingUCXPy() else: raise -# FIXME: cugraph/__init__.py also imports the comms module, but -# depending on the import environment, cugraph/comms/__init__.py -# may be imported instead. The following imports the comms.py -# module directly -from cugraph.dask.comms import comms as Comms # FIXME: We currently look for the default client from dask, as such is the @@ -76,42 +68,3 @@ def get_visible_devices(): else: visible_devices = _visible_devices.strip().split(",") return visible_devices - - -def setup_local_dask_cluster(p2p=True): - """ - Performs steps to setup a Dask cluster using LocalCUDACluster and returns - the LocalCUDACluster and corresponding client instance. - """ - cluster = LocalCUDACluster() - client = Client(cluster) - client.wait_for_workers(len(get_visible_devices())) - Comms.initialize(p2p=p2p) - - return (cluster, client) - - -def teardown_local_dask_cluster(cluster, client): - """ - Performs steps to destroy a Dask cluster and a corresponding client - instance. - """ - Comms.destroy() - client.close() - cluster.close() - - -def start_dask_client(): - n_devices = os.getenv("DASK_NUM_WORKERS", 2) - n_devices = int(n_devices) - - visible_devices = ",".join([str(i) for i in range(1, n_devices + 1)]) - - cluster = LocalCUDACluster( - protocol="ucx", rmm_pool_size="25GB", CUDA_VISIBLE_DEVICES=visible_devices - ) - client = Client(cluster) - Comms.initialize(p2p=True) - rmm.reinitialize(pool_allocator=True) - - return cluster, client diff --git a/python/cugraph/cugraph/experimental/datasets/__init__.py b/python/cugraph/cugraph/experimental/datasets/__init__.py index d12248c99ff..a1dd45b3d9f 100644 --- a/python/cugraph/cugraph/experimental/datasets/__init__.py +++ b/python/cugraph/cugraph/experimental/datasets/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -15,7 +15,6 @@ from cugraph.experimental.datasets.dataset import ( Dataset, load_all, - set_config, set_download_dir, get_download_dir, default_download_dir, diff --git a/python/cugraph/cugraph/experimental/datasets/dataset.py b/python/cugraph/cugraph/experimental/datasets/dataset.py index 36e6de487c0..6b395d50fef 100644 --- a/python/cugraph/cugraph/experimental/datasets/dataset.py +++ b/python/cugraph/cugraph/experimental/datasets/dataset.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -70,31 +70,94 @@ class Dataset: properties """ - def __init__(self, meta_data_file_name): - with open(meta_data_file_name, "r") as file: - self.metadata = yaml.safe_load(file) - + def __init__( + self, + metadata_yaml_file=None, + csv_file=None, + csv_header=None, + csv_delim=" ", + csv_col_names=None, + csv_col_dtypes=None, + ): + self._metadata_file = None self._dl_path = default_download_dir self._edgelist = None - self._graph = None self._path = None + + if metadata_yaml_file is not None and csv_file is not None: + raise ValueError("cannot specify both metadata_yaml_file and csv_file") + + elif metadata_yaml_file is not None: + with open(metadata_yaml_file, "r") as file: + self.metadata = yaml.safe_load(file) + self._metadata_file = Path(metadata_yaml_file) + + elif csv_file is not None: + if csv_col_names is None or csv_col_dtypes is None: + raise ValueError( + "csv_col_names and csv_col_dtypes must both be " + "not None when csv_file is specified." + ) + self._path = Path(csv_file) + if self._path.exists() is False: + raise FileNotFoundError(csv_file) + self.metadata = { + "name": self._path.with_suffix("").name, + "file_type": ".csv", + "url": None, + "header": csv_header, + "delim": csv_delim, + "col_names": csv_col_names, + "col_types": csv_col_dtypes, + } + + else: + raise ValueError("must specify either metadata_yaml_file or csv_file") + + def __str__(self): """ - self._path = self._dl_path.path / (self.metadata['name'] + - self.metadata['file_type']) + Use the basename of the meta_data_file the instance was constructed with, + without any extension, as the string repr. """ + # The metadata file is likely to have a more descriptive file name, so + # use that one first if present. + # FIXME: this may need to provide a more unique or descriptive string repr + if self._metadata_file is not None: + return self._metadata_file.with_suffix("").name + else: + return self.get_path().with_suffix("").name def __download_csv(self, url): + """ + Downloads the .csv file from url to the current download path + (self._dl_path), updates self._path with the full path to the + downloaded file, and returns the latest value of self._path. + """ self._dl_path.path.mkdir(parents=True, exist_ok=True) filename = self.metadata["name"] + self.metadata["file_type"] if self._dl_path.path.is_dir(): df = cudf.read_csv(url) - df.to_csv(self._dl_path.path / filename, index=False) + self._path = self._dl_path.path / filename + df.to_csv(self._path, index=False) else: raise RuntimeError( f"The directory {self._dl_path.path.absolute()}" "does not exist" ) + return self._path + + def unload(self): + + """ + Remove all saved internal objects, forcing them to be re-created when + accessed. + + NOTE: This will cause calls to get_*() to re-read the dataset file from + disk. The caller should ensure the file on disk has not moved/been + deleted/changed. + """ + self._edgelist = None def get_edgelist(self, fetch=False): """ @@ -106,12 +169,11 @@ def get_edgelist(self, fetch=False): Automatically fetch for the dataset from the 'url' location within the YAML file. 
""" - if self._edgelist is None: full_path = self.get_path() if not full_path.is_file(): if fetch: - self.__download_csv(self.metadata["url"]) + full_path = self.__download_csv(self.metadata["url"]) else: raise RuntimeError( f"The datafile {full_path} does not" @@ -131,7 +193,13 @@ def get_edgelist(self, fetch=False): return self._edgelist - def get_graph(self, fetch=False, create_using=Graph, ignore_weights=False): + def get_graph( + self, + fetch=False, + create_using=Graph, + ignore_weights=False, + store_transposed=False, + ): """ Return a Graph object. @@ -156,13 +224,13 @@ def get_graph(self, fetch=False, create_using=Graph, ignore_weights=False): self.get_edgelist(fetch) if create_using is None: - self._graph = Graph() + G = Graph() elif isinstance(create_using, Graph): # what about BFS if trnaposed is True attrs = {"directed": create_using.is_directed()} - self._graph = type(create_using)(**attrs) + G = type(create_using)(**attrs) elif type(create_using) is type: - self._graph = create_using() + G = create_using() else: raise TypeError( "create_using must be a cugraph.Graph " @@ -171,23 +239,30 @@ def get_graph(self, fetch=False, create_using=Graph, ignore_weights=False): ) if len(self.metadata["col_names"]) > 2 and not (ignore_weights): - self._graph.from_cudf_edgelist( - self._edgelist, source="src", destination="dst", edge_attr="wgt" + G.from_cudf_edgelist( + self._edgelist, + source="src", + destination="dst", + edge_attr="wgt", + store_transposed=store_transposed, ) else: - self._graph.from_cudf_edgelist( - self._edgelist, source="src", destination="dst" + G.from_cudf_edgelist( + self._edgelist, + source="src", + destination="dst", + store_transposed=store_transposed, ) - - return self._graph + return G def get_path(self): """ Returns the location of the stored dataset file """ - self._path = self._dl_path.path / ( - self.metadata["name"] + self.metadata["file_type"] - ) + if self._path is None: + self._path = self._dl_path.path / ( + self.metadata["name"] + self.metadata["file_type"] + ) return self._path.absolute() @@ -218,20 +293,6 @@ def load_all(force=False): df.to_csv(save_to, index=False) -def set_config(cfgpath): - """ - Read in a custom config file. - - Parameters - ---------- - cfgfile : String - Read the custom config file given its path, and override the default - """ - with open(Path(cfgpath), "r") as file: - cfg = yaml.safe_load(file) - default_download_dir.path = Path(cfg["download_dir"]) - - def set_download_dir(path): """ Set the download directory for fetching datasets diff --git a/python/cugraph/cugraph/testing/__init__.py b/python/cugraph/cugraph/testing/__init__.py index e69de29bb2d..db1c574de21 100644 --- a/python/cugraph/cugraph/testing/__init__.py +++ b/python/cugraph/cugraph/testing/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from cugraph.testing.utils import ( + RAPIDS_DATASET_ROOT_DIR_PATH, +) diff --git a/python/cugraph/cugraph/testing/mg_utils.py b/python/cugraph/cugraph/testing/mg_utils.py index 82dc2751ccf..1e1a481e4d6 100644 --- a/python/cugraph/cugraph/testing/mg_utils.py +++ b/python/cugraph/cugraph/testing/mg_utils.py @@ -34,7 +34,72 @@ def start_dask_client( jit_unspill=False, device_memory_limit=0.8, ): + """ + Creates a new dask client, and possibly also a cluster, and returns them as + a tuple (client, cluster). + + If the env var SCHEDULER_FILE is set, it is assumed to contain the path to + a JSON file generated by a running dask scheduler that can be used to + configure the new dask client (the new client object returned will be a + client to that scheduler), and the value of cluster will be None. If + SCHEDULER_FILE is not set, a new LocalCUDACluster will be created and + returned as the value of cluster. + + If the env var DASK_WORKER_DEVICES is set, it will be assumed to be a list + of comma-separated GPU devices (ex. "0,1,2" for those 3 devices) for the + LocalCUDACluster to use when setting up individual workers (1 worker per + device). If not set, the parameter dask_worker_devices will be used the + same way instead. If neither are set, the new LocalCUDACluster instance + will default to one worker per device visible to this process. + + If the env var DASK_LOCAL_DIRECTORY is set, it will be used as the + "local_directory" arg to LocalCUDACluster, for all temp files generated. + + Upon successful creation of a client (either to a LocalCUDACluster or + otherwise), the cugraph.dask.comms.comms singleton is initialized using + "p2p=True". + + Parameters + ---------- + protocol : str or None, default None + The "protocol" arg to LocalCUDACluster (ex. "tcp"), see docs for + dask_cuda.LocalCUDACluster for details. This parameter is ignored if + the env var SCHEDULER_FILE is set which implies the dask cluster has + already been created. + + rmm_pool_size : int, str or None, default None + The "rmm_pool_size" arg to LocalCUDACluster (ex. "20GB"), see docs for + dask_cuda.LocalCUDACluster for details. This parameter is ignored if + the env var SCHEDULER_FILE is set which implies the dask cluster has + already been created. + + dask_worker_devices : str, list of int, or None, default None + GPUs to restrict activity to. Can be a string (like ``"0,1,2,3"``), + list (like ``[0, 1, 2, 3]``), or ``None`` to use all available GPUs. + This parameter is overridden by the value of env var + DASK_WORKER_DEVICES. This parameter is ignored if the env var + SCHEDULER_FILE is set which implies the dask cluster has already been + created. + + jit_unspill : bool or None, default None + The "jit_unspill" arg to LocalCUDACluster to enable just-in-time + spilling, see docs for dask_cuda.LocalCUDACluster for details. This + parameter is ignored if the env var SCHEDULER_FILE is set which implies + the dask cluster has already been created. + + device_memory_limit : int, float, str, or None, default 0.8 + The "device_memory_limit" arg to LocalCUDACluster to determine when + workers start spilling to host memory, see docs for + dask_cuda.LocalCUDACluster for details. This parameter is ignored if + the env var SCHEDULER_FILE is set which implies the dask cluster has + already been created. + """ dask_scheduler_file = os.environ.get("SCHEDULER_FILE") + dask_local_directory = os.getenv("DASK_LOCAL_DIRECTORY") + # Allow the DASK_WORKER_DEVICES env var to override a value passed in. 
If + # neither are set, this will be None. + dask_worker_devices = os.getenv("DASK_WORKER_DEVICES", dask_worker_devices) + cluster = None client = None tempdir_object = None @@ -55,16 +120,21 @@ def start_dask_client( f"WARNING: {dask_worker_devices=} is ignored in start_dask_client() " "when using dask SCHEDULER_FILE" ) - initialize() client = Client(scheduler_file=dask_scheduler_file) - print("\ndask_client created using " f"{dask_scheduler_file}") + # FIXME: use proper logging, INFO or DEBUG level + print("\nDask client created using " f"{dask_scheduler_file}") else: - # The tempdir created by tempdir_object should be cleaned up once - # tempdir_object goes out-of-scope and is deleted. - tempdir_object = tempfile.TemporaryDirectory() + if dask_local_directory is None: + # The tempdir created by tempdir_object should be cleaned up once + # tempdir_object is deleted. + tempdir_object = tempfile.TemporaryDirectory() + local_directory = tempdir_object.name + else: + local_directory = dask_local_directory + cluster = LocalCUDACluster( - local_directory=tempdir_object.name, + local_directory=local_directory, protocol=protocol, rmm_pool_size=rmm_pool_size, CUDA_VISIBLE_DEVICES=dask_worker_devices, @@ -72,8 +142,23 @@ def start_dask_client( device_memory_limit=device_memory_limit, ) client = Client(cluster) - client.wait_for_workers(len(get_visible_devices())) - print("\ndask_client created using LocalCUDACluster") + + if dask_worker_devices is None: + num_workers = len(get_visible_devices()) + else: + if isinstance(dask_worker_devices, list): + num_workers = len(dask_worker_devices) + else: + # FIXME: this assumes a properly formatted string with commas + num_workers = len(dask_worker_devices.split(",")) + + client.wait_for_workers(num_workers) + # Add a reference to tempdir_object to the client to prevent it from + # being deleted when this function returns. This will be deleted in + # stop_dask_client() + client.tempdir_object = tempdir_object + # FIXME: use proper logging, INFO or DEBUG level + print("\nDask client/cluster created using LocalCUDACluster") Comms.initialize(p2p=True) @@ -81,11 +166,21 @@ def start_dask_client( def stop_dask_client(client, cluster=None): + """ + Shutdown/cleanup a client and possibly cluster object returned from + start_dask_client(). This also stops the cugraph.dask.comms.comms + singleton. + """ Comms.destroy() client.close() if cluster: cluster.close() - print("\ndask_client closed.") + # Remove a TemporaryDirectory object that may have been assigned to the + # client, which should remove it and all the contents from disk. + if hasattr(client, "tempdir_object"): + del client.tempdir_object + # FIXME: use proper logging, INFO or DEBUG level + print("\nDask client closed.") def restart_client(client): diff --git a/python/cugraph/cugraph/testing/utils.py b/python/cugraph/cugraph/testing/utils.py index 46bc8d99e83..0dae17ed14e 100644 --- a/python/cugraph/cugraph/testing/utils.py +++ b/python/cugraph/cugraph/testing/utils.py @@ -131,7 +131,6 @@ def read_csv_for_nx(csv_file, read_weights_in_sp=True, read_weights=True): - print("Reading " + str(csv_file) + "...") if read_weights: if read_weights_in_sp is True: df = pd.read_csv( diff --git a/python/cugraph/cugraph/tests/conftest.py b/python/cugraph/cugraph/tests/conftest.py index 388a90d4e98..fece006c4b8 100644 --- a/python/cugraph/cugraph/tests/conftest.py +++ b/python/cugraph/cugraph/tests/conftest.py @@ -12,9 +12,9 @@ # limitations under the License. 
import pytest -from cugraph.dask.common.mg_utils import ( - setup_local_dask_cluster, - teardown_local_dask_cluster, +from cugraph.testing.mg_utils import ( + start_dask_client, + stop_dask_client, ) # module-wide fixtures @@ -34,7 +34,11 @@ def gpubenchmark(): @pytest.fixture(scope="module") def dask_client(): - cluster, client = setup_local_dask_cluster() - yield client + # start_dask_client will check for the SCHEDULER_FILE and + # DASK_WORKER_DEVICES env vars and use them when creating a client if + # set. start_dask_client will also initialize the Comms singleton. + dask_client, dask_cluster = start_dask_client() - teardown_local_dask_cluster(cluster, client) + yield dask_client + + stop_dask_client(dask_client, dask_cluster) diff --git a/python/cugraph/cugraph/tests/generators/test_rmat_mg.py b/python/cugraph/cugraph/tests/generators/test_rmat_mg.py index 22403a189b8..d5d6db4d70f 100644 --- a/python/cugraph/cugraph/tests/generators/test_rmat_mg.py +++ b/python/cugraph/cugraph/tests/generators/test_rmat_mg.py @@ -16,10 +16,12 @@ import dask_cudf +from cugraph.testing.mg_utils import ( + start_dask_client, + stop_dask_client, +) from cugraph.dask.common.mg_utils import ( is_single_gpu, - setup_local_dask_cluster, - teardown_local_dask_cluster, ) from cugraph.generators import rmat import cugraph @@ -61,13 +63,13 @@ def setup_module(): global _client global _visible_devices if not _is_single_gpu: - (_cluster, _client) = setup_local_dask_cluster(p2p=True) + (_client, _cluster) = start_dask_client() _visible_devices = _client.scheduler_info()["workers"] def teardown_module(): if not _is_single_gpu: - teardown_local_dask_cluster(_cluster, _client) + stop_dask_client(_client, _cluster) ############################################################################### diff --git a/python/cugraph/cugraph/tests/sampling/test_uniform_neighbor_sample_mg.py b/python/cugraph/cugraph/tests/sampling/test_uniform_neighbor_sample_mg.py index 779eabd1ecf..dc70e7153d7 100644 --- a/python/cugraph/cugraph/tests/sampling/test_uniform_neighbor_sample_mg.py +++ b/python/cugraph/cugraph/tests/sampling/test_uniform_neighbor_sample_mg.py @@ -478,7 +478,8 @@ def test_uniform_neighbor_sample_edge_properties_self_loops(dask_client): @pytest.mark.mg @pytest.mark.parametrize("with_replacement", [True, False]) @pytest.mark.skipif( - int(os.getenv("DASK_NUM_WORKERS", 2)) < 2, reason="too few workers to test" + len(os.getenv("DASK_WORKER_DEVICES", "0").split(",")) < 2, + reason="too few workers to test", ) def test_uniform_neighbor_edge_properties_sample_small_start_list( dask_client, with_replacement diff --git a/python/cugraph/cugraph/tests/utils/test_dataset.py b/python/cugraph/cugraph/tests/utils/test_dataset.py index 6a145833c7f..e72de2ecf8a 100644 --- a/python/cugraph/cugraph/tests/utils/test_dataset.py +++ b/python/cugraph/cugraph/tests/utils/test_dataset.py @@ -12,60 +12,71 @@ # limitations under the License. 
-import pytest -import yaml import os from pathlib import Path -from tempfile import NamedTemporaryFile, TemporaryDirectory -from cugraph.experimental.datasets import ALL_DATASETS, ALL_DATASETS_WGT, SMALL_DATASETS -from cugraph.structure import Graph +from tempfile import TemporaryDirectory +import gc +import pytest -# ============================================================================= -# Pytest Setup / Teardown - called for each test function -# ============================================================================= +from cugraph.structure import Graph +from cugraph.testing import RAPIDS_DATASET_ROOT_DIR_PATH +from cugraph.experimental.datasets import ( + ALL_DATASETS, + ALL_DATASETS_WGT, + SMALL_DATASETS, +) +from cugraph.experimental import datasets -dataset_path = Path(__file__).parents[4] / "datasets" +# Add the sg marker to all tests in this module. +pytestmark = pytest.mark.sg -# Use this to simulate a fresh API import -@pytest.fixture -def datasets(): - from cugraph.experimental import datasets +############################################################################### +# Fixtures - yield datasets - del datasets - clear_locals() +# module fixture - called once for this module +@pytest.fixture(scope="module") +def tmpdir(): + """ + Create a tmp dir for downloads, etc., run a test, then cleanup when the + test is done. + """ + tmpd = TemporaryDirectory() + yield tmpd + # teardown + tmpd.cleanup() -def clear_locals(): +# function fixture - called once for each function in this module +@pytest.fixture(scope="function", autouse=True) +def setup(tmpdir): + """ + Fixture used for individual test setup and teardown. This ensures each + Dataset object starts with the same state and cleans up when the test is + done. + """ + # FIXME: this relies on dataset features (unload) which themselves are + # being tested in this module. for dataset in ALL_DATASETS: - dataset._edgelist = None - dataset._graph = None - dataset._path = None + dataset.unload() + gc.collect() + datasets.set_download_dir(tmpdir.name) -# We use this to create tempfiles that act as config files when we call -# set_config(). 
Arguments passed will act as custom download directories -def create_config(custom_path="custom_storage_location"): - config_yaml = """ - fetch: False - force: False - download_dir: None - """ - c = yaml.safe_load(config_yaml) - c["download_dir"] = custom_path + yield - outfile = NamedTemporaryFile() - with open(outfile.name, "w") as f: - yaml.dump(c, f, sort_keys=False) + # teardown + for dataset in ALL_DATASETS: + dataset.unload() + gc.collect() - return outfile +############################################################################### +# Tests # setting download_dir to None effectively re-initialized the default -@pytest.mark.sg -def test_env_var(datasets): +def test_env_var(): os.environ["RAPIDS_DATASET_ROOT_DIR"] = "custom_storage_location" datasets.set_download_dir(None) @@ -75,26 +86,14 @@ def test_env_var(datasets): del os.environ["RAPIDS_DATASET_ROOT_DIR"] -@pytest.mark.sg -def test_home_dir(datasets): +def test_home_dir(): datasets.set_download_dir(None) expected_path = Path.home() / ".cugraph/datasets" assert datasets.get_download_dir() == expected_path -@pytest.mark.sg -def test_set_config(datasets): - cfg = create_config() - datasets.set_config(cfg.name) - - assert datasets.get_download_dir() == Path("custom_storage_location").absolute() - - cfg.close() - - -@pytest.mark.sg -def test_set_download_dir(datasets): +def test_set_download_dir(): tmpd = TemporaryDirectory() datasets.set_download_dir(tmpd.name) @@ -103,59 +102,26 @@ def test_set_download_dir(datasets): tmpd.cleanup() -@pytest.mark.sg -@pytest.mark.skip( - reason="Timeout errors; see: https://github.com/rapidsai/cugraph/issues/2810" -) -def test_load_all(datasets): - tmpd = TemporaryDirectory() - cfg = create_config(custom_path=tmpd.name) - datasets.set_config(cfg.name) - datasets.load_all() - - for data in datasets.ALL_DATASETS: - file_path = Path(tmpd.name) / ( - data.metadata["name"] + data.metadata["file_type"] - ) - assert file_path.is_file() - - tmpd.cleanup() - - -@pytest.mark.sg @pytest.mark.parametrize("dataset", ALL_DATASETS) -def test_fetch(dataset, datasets): - tmpd = TemporaryDirectory() - cfg = create_config(custom_path=tmpd.name) - datasets.set_config(cfg.name) - +def test_fetch(dataset): E = dataset.get_edgelist(fetch=True) assert E is not None assert dataset.get_path().is_file() - tmpd.cleanup() - -@pytest.mark.sg @pytest.mark.parametrize("dataset", ALL_DATASETS) -def test_get_edgelist(dataset, datasets): - datasets.set_download_dir(dataset_path) +def test_get_edgelist(dataset): E = dataset.get_edgelist(fetch=True) - assert E is not None -@pytest.mark.sg @pytest.mark.parametrize("dataset", ALL_DATASETS) -def test_get_graph(dataset, datasets): - datasets.set_download_dir(dataset_path) +def test_get_graph(dataset): G = dataset.get_graph(fetch=True) - assert G is not None -@pytest.mark.sg @pytest.mark.parametrize("dataset", ALL_DATASETS) def test_metadata(dataset): M = dataset.metadata @@ -163,9 +129,8 @@ def test_metadata(dataset): assert M is not None -@pytest.mark.sg @pytest.mark.parametrize("dataset", ALL_DATASETS) -def test_get_path(dataset, datasets): +def test_get_path(dataset): tmpd = TemporaryDirectory() datasets.set_download_dir(tmpd.name) dataset.get_edgelist(fetch=True) @@ -174,27 +139,103 @@ def test_get_path(dataset, datasets): tmpd.cleanup() -@pytest.mark.sg @pytest.mark.parametrize("dataset", ALL_DATASETS_WGT) -def test_weights(dataset, datasets): - datasets.set_download_dir(dataset_path) - - G_w = dataset.get_graph(fetch=True) +def test_weights(dataset): + G = 
dataset.get_graph(fetch=True) + assert G.is_weighted() G = dataset.get_graph(fetch=True, ignore_weights=True) - - assert G_w.is_weighted() assert not G.is_weighted() -@pytest.mark.sg @pytest.mark.parametrize("dataset", SMALL_DATASETS) -def test_create_using(dataset, datasets): - datasets.set_download_dir(dataset_path) +def test_create_using(dataset): + G = dataset.get_graph(fetch=True) + assert not G.is_directed() + G = dataset.get_graph(fetch=True, create_using=Graph) + assert not G.is_directed() + G = dataset.get_graph(fetch=True, create_using=Graph(directed=True)) + assert G.is_directed() - G_d = dataset.get_graph() - G_t = dataset.get_graph(create_using=Graph) - G = dataset.get_graph(create_using=Graph(directed=True)) - assert not G_d.is_directed() - assert not G_t.is_directed() - assert G.is_directed() +def test_ctor_with_datafile(): + from cugraph.experimental.datasets import karate + + karate_csv = RAPIDS_DATASET_ROOT_DIR_PATH / "karate.csv" + + # test that only a metadata file or csv can be specified, not both + with pytest.raises(ValueError): + datasets.Dataset(metadata_yaml_file="metadata_file", csv_file=karate_csv) + + # ensure at least one arg is provided + with pytest.raises(ValueError): + datasets.Dataset() + + # ensure csv file has all other required args (col names and col dtypes) + with pytest.raises(ValueError): + datasets.Dataset(csv_file=karate_csv) + + with pytest.raises(ValueError): + datasets.Dataset(csv_file=karate_csv, csv_col_names=["src", "dst", "wgt"]) + + # test with file that DNE + with pytest.raises(FileNotFoundError): + datasets.Dataset( + csv_file="/some/file/that/does/not/exist", + csv_col_names=["src", "dst", "wgt"], + csv_col_dtypes=["int32", "int32", "float32"], + ) + + expected_karate_edgelist = karate.get_edgelist(fetch=True) + + # test with file path as string, ensure fetch=True does not break + ds = datasets.Dataset( + csv_file=karate_csv.as_posix(), + csv_col_names=["src", "dst", "wgt"], + csv_col_dtypes=["int32", "int32", "float32"], + ) + # cudf.testing.testing.assert_frame_equal() would be good to use to + # compare, but for some reason it seems to be holding a reference to a + # dataframe and gc.collect() does not free everything + el = ds.get_edgelist() + assert len(el) == len(expected_karate_edgelist) + assert str(ds) == "karate" + assert ds.get_path() == karate_csv + + # test with file path as Path object + ds = datasets.Dataset( + csv_file=karate_csv, + csv_col_names=["src", "dst", "wgt"], + csv_col_dtypes=["int32", "int32", "float32"], + ) + el = ds.get_edgelist() + assert len(el) == len(expected_karate_edgelist) + assert str(ds) == "karate" + assert ds.get_path() == karate_csv + + +def test_unload(): + email_csv = RAPIDS_DATASET_ROOT_DIR_PATH / "email-Eu-core.csv" + + ds = datasets.Dataset( + csv_file=email_csv.as_posix(), + csv_col_names=["src", "dst", "wgt"], + csv_col_dtypes=["int32", "int32", "float32"], + ) + + # FIXME: another (better?) test would be to check free memory and assert + # the memory use increases after get_*(), then returns to the pre-get_*() + # level after unload(). However, that type of test may fail for several + # reasons (the device being monitored is accidentally also being used by + # another process, and the use of memory pools to name two). Instead, just + # test that the internal members get cleared on unload(). 
+    assert ds._edgelist is None
+
+    ds.get_edgelist()
+    assert ds._edgelist is not None
+    ds.unload()
+    assert ds._edgelist is None
+
+    ds.get_graph()
+    assert ds._edgelist is not None
+    ds.unload()
+    assert ds._edgelist is None
diff --git a/python/pylibcugraph/pylibcugraph/testing/utils.py b/python/pylibcugraph/pylibcugraph/testing/utils.py
index f578a146f4b..50fe18fe13d 100644
--- a/python/pylibcugraph/pylibcugraph/testing/utils.py
+++ b/python/pylibcugraph/pylibcugraph/testing/utils.py
@@ -36,17 +36,17 @@ def gen_fixture_params(*param_values):
     combination passed in, or a callable that accepts a list of values and
     returns a string.
 
-    gen_fixture_params( (pytest.param(True, marks=[pytest.mark.A_good], id="A=True"),
-                         pytest.param(False, marks=[pytest.mark.B_bad], id="B=False")),
-                        (pytest.param(False, marks=[pytest.mark.A_bad], id="A=False"),
-                         pytest.param(True, marks=[pytest.mark.B_good], id="B=True")),
+    gen_fixture_params( (pytest.param(True, marks=[pytest.mark.A_good], id="A:True"),
+                         pytest.param(False, marks=[pytest.mark.B_bad], id="B:False")),
+                        (pytest.param(False, marks=[pytest.mark.A_bad], id="A:False"),
+                         pytest.param(True, marks=[pytest.mark.B_good], id="B:True")),
                       )
 
     results in fixture param combinations:
 
-    True, False   - marks=[A_good, B_bad]   - id="A=True,B=False"
-    False, False  - marks=[A_bad, B_bad]    - id="A=False,B=True"
+    True, False   - marks=[A_good, B_bad]   - id="A:True-B:False"
+    False, True   - marks=[A_bad, B_good]   - id="A:False-B:True"
     """
     fixture_params = []
     param_type = pytest.param().__class__  #
@@ -89,10 +89,10 @@ def gen_fixture_params_product(*args):
 
     results in fixture param combinations:
 
-    True, True    - marks=[A_good, B_good]  - id="A=True,B=True"
-    True, False   - marks=[A_good, B_bad]   - id="A=True,B=False"
-    False, True   - marks=[A_bad, B_good]   - id="A=False,B=True"
-    False, False  - marks=[A_bad, B_bad]    - id="A=False,B=False"
+    True, True    - marks=[A_good, B_good]  - id="A:True-B:True"
+    True, False   - marks=[A_good, B_bad]   - id="A:True-B:False"
+    False, True   - marks=[A_bad, B_good]   - id="A:False-B:True"
+    False, False  - marks=[A_bad, B_bad]    - id="A:False-B:False"
 
     Simply using itertools.product on the lists would result in a list of
     sublists of individual param objects (ie. not "merged"), which would not be
@@ -124,9 +124,9 @@ def gen_fixture_params_product(*args):
         for (p, paramId) in zip(paramCombo, ids):
             # Assume paramId is either a string or a callable
             if isinstance(paramId, str):
-                id_strings.append("%s=%s" % (paramId, p.values[0]))
+                id_strings.append("%s:%s" % (paramId, p.values[0]))
             else:
                 id_strings.append(paramId(p.values[0]))
-        comboid = ",".join(id_strings)
+        comboid = "-".join(id_strings)
         retList.append(pytest.param(values, marks=marks, id=comboid))
     return retList
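
For reference, the pieces this patch introduces compose roughly as follows. This is a minimal sketch, not part of the patch: it assumes a CUDA environment with cugraph and dask_cuda installed, and it assumes RmatDataset and is_graph_distributed are copied or imported from the bench_algos.py changes above (they are module-local there, not a published API).

# Hypothetical driver script illustrating the pattern added in this PR.
import os

import cugraph
import cugraph.dask as dask_cugraph
from cugraph.testing.mg_utils import start_dask_client, stop_dask_client

# DASK_WORKER_DEVICES replaces DASK_NUM_WORKERS throughout this PR: a
# comma-separated GPU list, here leaving device 0 for the client process.
os.environ.setdefault("DASK_WORKER_DEVICES", "1,2,3")

# start_dask_client() honors SCHEDULER_FILE, DASK_WORKER_DEVICES, and
# DASK_LOCAL_DIRECTORY, and initializes the Comms singleton (see mg_utils.py).
client, cluster = start_dask_client()

# RmatDataset is duck-type compatible with cugraph.experimental.datasets.Dataset,
# so the benchmark fixtures treat file-backed and synthetic data uniformly.
ds = RmatDataset(scale=16, edgefactor=16, mg=True)  # assumed from bench_algos.py
G = ds.get_graph(store_transposed=True)  # pagerank reads the transpose

# The SG/MG dispatch pattern used by every updated benchmark:
pagerank = dask_cugraph.pagerank if is_graph_distributed(G) else cugraph.pagerank
result = pagerank(G)

ds.unload()
stop_dask_client(client, cluster)

The equivalent benchmark run would be along the lines of pytest -m "rmat_data and mg" --rmat-scale=16 --rmat-edgefactor=16 bench_algos.py, using the marks added to pytest.ini and the CLI options added to conftest.py above.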