Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Datasets API Update: Add Extra Params and Improve Testing #2453

Merged
merged 8 commits into from
Aug 1, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions python/cugraph/cugraph/experimental/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,20 @@
small_tree = Dataset(meta_path / "small_tree.yaml")


# Named groups of Dataset objects. The test module imports some of these
# (ALL_DATASETS, ALL_DATASETS_WGT, SMALL_DATASETS) to parametrize tests.
#
# NOTE(review): MEDIUM_DATASETS, SMALL_DATASETS and ALL_DATASETS are each
# assigned twice below. Read top to bottom, the later assignment of each name
# is the one that takes effect. This looks like the old and new sides of a
# diff interleaved -- confirm which definition of each group is intended.

# LARGE DATASETS
LARGE_DATASETS = [cyber]
MEDIUM_DATASETS = [polbooks]

# <10,000 lines
MEDIUM_DATASETS = [netscience, polbooks]
SMALL_DATASETS = [karate, dolphins, netscience]

# <500 lines
SMALL_DATASETS = [karate, small_line, small_tree, dolphins]
RLY_SMALL_DATASETS = [small_line, small_tree]

# ALL
# NOTE(review): the first ALL_DATASETS below includes `cyber`, the second
# does not; the second one wins as written.
ALL_DATASETS = [karate, dolphins, netscience, polbooks, cyber,
small_line, small_tree]
ALL_DATASETS = [karate, dolphins, netscience, polbooks,
small_line, small_tree]

# Presumably the datasets that carry an edge-weight column (used to
# parametrize the weighted/unweighted graph test) -- TODO confirm.
ALL_DATASETS_WGT = [dolphins, netscience, polbooks,
small_line, small_tree]

TEST_GROUP = [dolphins, netscience]

DATASETS_KTRUSS = [polbooks]

DATASETS_UNDIRECTED = [karate, dolphins]
38 changes: 30 additions & 8 deletions python/cugraph/cugraph/experimental/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import cugraph
import cudf
import yaml
import os
from pathlib import Path
from cugraph.structure.graph_classes import Graph


class DefaultDownloadDir:
Expand Down Expand Up @@ -64,7 +64,6 @@ class Dataset:
The metadata file for the specific graph dataset, which includes
information on the name, type, url link, data loading format, graph
properties

"""
def __init__(self, meta_data_file_name):
with open(meta_data_file_name, 'r') as file:
Expand Down Expand Up @@ -118,22 +117,45 @@ def get_edgelist(self, fetch=False):

return self._edgelist

def get_graph(self, fetch=False, create_using=Graph, ignore_weights=False):
    """
    Return a Graph object built from this dataset's edge list.

    Parameters
    ----------
    fetch : Boolean (default=False)
        Downloads the dataset from the web. Only consulted when the edge
        list has not been loaded yet, in which case it is forwarded to
        get_edgelist().

    create_using: cugraph.Graph (instance or class), optional
        (default=Graph)
        Specify the type of Graph to create. A class is instantiated with
        no arguments. An instance is used as a template: a new graph of
        the same type and with the same 'directed' attribute is created,
        and the instance passed in is left untouched. None behaves like
        the default.

    ignore_weights : Boolean (default=False)
        If True, the graph is created without edge weights. If False, the
        'wgt' column is used as the edge attribute whenever the dataset
        metadata lists more than two columns; datasets without a weight
        column always produce an unweighted graph.

    Returns
    -------
    The newly created graph. It is also stored on self._graph.

    Raises
    ------
    TypeError
        If create_using is neither None, a cugraph.Graph (or subclass)
        type, nor an instance of one.
    """
    if self._edgelist is None:
        self.get_edgelist(fetch)

    if create_using is None:
        self._graph = Graph()
    elif isinstance(create_using, Graph):
        # Build a fresh graph rather than reusing the caller's instance,
        # carrying over only its directedness.
        attrs = {"directed": create_using.is_directed()}
        self._graph = type(create_using)(**attrs)
    elif (isinstance(create_using, type)
          and issubclass(create_using, Graph)):
        # Comparing metaclasses (type(x) is type(Graph)) would accept any
        # class at all; require an actual Graph subclass so that invalid
        # input reaches the TypeError below.
        self._graph = create_using()
    else:
        raise TypeError("create_using must be a cugraph.Graph "
                        "(or subclass) type or instance, got: "
                        f"{type(create_using)}")

    # A third column in the metadata means the edge list carries weights.
    edgelist_kwargs = {"source": "src", "destination": "dst"}
    if len(self.metadata['col_names']) > 2 and not ignore_weights:
        edgelist_kwargs["edge_attr"] = "wgt"
    self._graph.from_cudf_edgelist(self._edgelist, **edgelist_kwargs)

    return self._graph

Expand Down
41 changes: 32 additions & 9 deletions python/cugraph/cugraph/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@
import os
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from cugraph.experimental.datasets import (ALL_DATASETS)
from cugraph.experimental.datasets import (ALL_DATASETS, ALL_DATASETS_WGT,
SMALL_DATASETS)
from cugraph.structure import Graph


# =============================================================================
# Pytest Setup / Teardown - called for each test function
# =============================================================================

dataset_path = Path(__file__).parents[4] / "datasets"


# Use this to simulate a fresh API import
@pytest.fixture
def datasets():
Expand Down Expand Up @@ -125,25 +130,19 @@ def test_fetch(dataset, datasets):

@pytest.mark.parametrize("dataset", ALL_DATASETS)
def test_get_edgelist(dataset, datasets):
    """Check that get_edgelist() returns an edge list for every dataset."""
    # Use the module-level dataset_path as the download directory. No
    # temporary directory is needed: one created here would be overridden
    # immediately and never cleaned up if the assertion failed.
    datasets.set_download_dir(dataset_path)
    E = dataset.get_edgelist(fetch=True)

    assert E is not None


@pytest.mark.parametrize("dataset", ALL_DATASETS)
def test_get_graph(dataset, datasets):
    """Check that get_graph() returns a graph for every dataset."""
    # Use the module-level dataset_path as the download directory. No
    # temporary directory is needed: one created here would be overridden
    # immediately and never cleaned up if the assertion failed.
    datasets.set_download_dir(dataset_path)
    G = dataset.get_graph(fetch=True)

    assert G is not None


@pytest.mark.parametrize("dataset", ALL_DATASETS)
def test_metadata(dataset):
Expand All @@ -167,3 +166,27 @@ def test_get_path(dataset, datasets):
def test_get_path_raises(dataset):
with pytest.raises(RuntimeError):
dataset.get_path()


@pytest.mark.parametrize("dataset", ALL_DATASETS_WGT)
def test_weights(dataset, datasets):
    """
    Check that get_graph() keeps edge weights by default and drops them
    when ignore_weights=True.
    """
    datasets.set_download_dir(dataset_path)

    weighted = dataset.get_graph(fetch=True)
    unweighted = dataset.get_graph(fetch=True, ignore_weights=True)

    assert weighted.is_weighted()
    assert not unweighted.is_weighted()


@pytest.mark.parametrize("dataset", SMALL_DATASETS)
def test_create_using(dataset, datasets):
    """
    Check the directedness of graphs produced by get_graph() for the
    default, a Graph class, and a directed Graph instance as create_using.
    """
    datasets.set_download_dir(dataset_path)

    from_default = dataset.get_graph()
    from_class = dataset.get_graph(create_using=Graph)
    # A directed instance serves as the template for a directed result.
    directed_template = Graph(directed=True)
    from_instance = dataset.get_graph(create_using=directed_template)

    assert not from_default.is_directed()
    assert not from_class.is_directed()
    assert from_instance.is_directed()