Skip to content

Commit

Permalink
Datasets API Update: Add Extra Params and Improve Testing (rapidsai#2453)
Browse files Browse the repository at this point in the history

Adding two new parameters to the `get_graph()` method within the datasets API.

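- `create_using` allows users to choose the type of graph object that is returned (for example, directed or undirected).
- `ignore_weights` specifies whether the dataset's weights are left out, i.e. whether the cugraph.Graph object is created without an `edge_attr` field.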

`test_dataset.py` has been updated to:
1. Add coverage for the new parameters
2. Avoid fetching datasets from the web when it is unnecessary.

Docstrings have also been updated for clarity.

Authors:
  - Ralph Liu (https://github.com/oorliu)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)

URL: rapidsai#2453
  • Loading branch information
oorliu authored Aug 1, 2022
1 parent 93e15c0 commit 2a57740
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 29 deletions.
24 changes: 15 additions & 9 deletions python/cugraph/cugraph/experimental/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
meta_path = Path(__file__).parent / "metadata"

karate = Dataset(meta_path / "karate.yaml")
karate_data = Dataset(meta_path / "karate_data.yaml")
karate_undirected = Dataset(meta_path / "karate_undirected.yaml")
karate_asymmetric = Dataset(meta_path / "karate_asymmetric.yaml")
dolphins = Dataset(meta_path / "dolphins.yaml")
Expand All @@ -37,15 +38,20 @@
small_tree = Dataset(meta_path / "small_tree.yaml")


# LARGE DATASETS
LARGE_DATASETS = [cyber]
MEDIUM_DATASETS = [polbooks]

# <10,000 lines
MEDIUM_DATASETS = [netscience, polbooks]
SMALL_DATASETS = [karate, dolphins, netscience]

# <500 lines
SMALL_DATASETS = [karate, small_line, small_tree, dolphins]
RLY_SMALL_DATASETS = [small_line, small_tree]

# ALL
ALL_DATASETS = [karate, dolphins, netscience, polbooks, cyber,
small_line, small_tree]
ALL_DATASETS = [karate, dolphins, netscience, polbooks,
small_line, small_tree]

ALL_DATASETS_WGT = [karate, dolphins, netscience, polbooks,
small_line, small_tree]

TEST_GROUP = [dolphins, netscience]

DATASETS_KTRUSS = [polbooks]

DATASETS_UNDIRECTED = [karate_undirected, small_line, karate_asymmetric]
41 changes: 33 additions & 8 deletions python/cugraph/cugraph/experimental/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import cugraph
import cudf
import yaml
import os
from pathlib import Path
from cugraph.structure.graph_classes import Graph


class DefaultDownloadDir:
Expand Down Expand Up @@ -64,7 +64,6 @@ class Dataset:
The metadata file for the specific graph dataset, which includes
information on the name, type, url link, data loading format, graph
properties
"""
def __init__(self, meta_data_file_name):
with open(meta_data_file_name, 'r') as file:
Expand Down Expand Up @@ -118,22 +117,48 @@ def get_edgelist(self, fetch=False):

return self._edgelist

def get_graph(self, fetch=False, create_using=Graph, ignore_weights=False):
    """
    Return a Graph object built from this dataset's edge list.

    Parameters
    ----------
    fetch : Boolean (default=False)
        Automatically fetch the dataset from the 'url' location within
        the YAML file (i.e. download it from the web). Only consulted
        when the edge list has not been loaded yet.

    create_using: cugraph.Graph (instance or class), optional
        (default=Graph)
        Specify the type of Graph to create. Can pass in an instance to
        create a Graph instance with specified 'directed' attribute.
        Passing None creates a default Graph.

    ignore_weights : Boolean (default=False)
        Ignores weights in the dataset if True, resulting in an
        unweighted Graph. If False (the default), weights from the
        dataset -if present- will be applied to the Graph. If the
        dataset does not contain weights, the Graph returned will
        be unweighted regardless of ignore_weights.

    Returns
    -------
    The newly created graph, which is also stored in self._graph.

    Raises
    ------
    TypeError
        If create_using is not None, a Graph instance, or a class.
    """
    if self._edgelist is None:
        self.get_edgelist(fetch)

    if create_using is None:
        self._graph = Graph()
    elif isinstance(create_using, Graph):
        # The instance is only a template: a fresh graph of the same type
        # is created and just its 'directed' attribute is carried over.
        self._graph = type(create_using)(
            directed=create_using.is_directed())
    elif isinstance(create_using, type):
        # isinstance() also accepts classes built with a custom metaclass,
        # which an exact `type(x) is type` comparison would reject.
        self._graph = create_using()
    else:
        raise TypeError("create_using must be a cugraph.Graph "
                        "(or subclass) type or instance, got: "
                        f"{type(create_using)}")

    edgelist_args = {"source": "src", "destination": "dst"}
    # A third metadata column is taken to be the weight column 'wgt'.
    if len(self.metadata['col_names']) > 2 and not ignore_weights:
        edgelist_args["edge_attr"] = "wgt"
    self._graph.from_cudf_edgelist(self._edgelist, **edgelist_args)

    return self._graph

Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
name: karate-data
name: karate
file_type: .csv
author: Zachary W.
url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/karate-data.csv
url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/karate.csv
refs:
W. W. Zachary, An information flow model for conflict and fission in small groups,
Journal of Anthropological Research 33, 452-473 (1977).
delim: "\t"
delim: " "
col_names:
- src
- dst
- wgt
col_types:
- int32
- int32
- float32
has_loop: true
is_directed: true
is_multigraph: false
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: karate-data
file_type: .csv
author: Zachary W.
url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/karate-data.csv
refs:
W. W. Zachary, An information flow model for conflict and fission in small groups,
Journal of Anthropological Research 33, 452-473 (1977).
delim: "\t"
col_names:
- src
- dst
col_types:
- int32
- int32
has_loop: true
is_directed: true
is_multigraph: false
is_symmetric: true
number_of_edges: 156
number_of_nodes: 34
number_of_lines: 156
41 changes: 32 additions & 9 deletions python/cugraph/cugraph/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@
import os
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from cugraph.experimental.datasets import (ALL_DATASETS)
from cugraph.experimental.datasets import (ALL_DATASETS, ALL_DATASETS_WGT,
SMALL_DATASETS)
from cugraph.structure import Graph


# =============================================================================
# Pytest Setup / Teardown - called for each test function
# =============================================================================

dataset_path = Path(__file__).parents[4] / "datasets"


# Use this to simulate a fresh API import
@pytest.fixture
def datasets():
Expand Down Expand Up @@ -125,25 +130,19 @@ def test_fetch(dataset, datasets):

@pytest.mark.parametrize("dataset", ALL_DATASETS)
def test_get_edgelist(dataset, datasets):
    """Check that every dataset yields an edge list."""
    # Use the module-level datasets directory directly; creating a
    # temporary download directory first would be immediately overridden.
    # NOTE(review): presumably files already present in dataset_path are
    # reused instead of being downloaded - confirm against get_edgelist().
    datasets.set_download_dir(dataset_path)
    E = dataset.get_edgelist(fetch=True)

    assert E is not None


@pytest.mark.parametrize("dataset", ALL_DATASETS)
def test_get_graph(dataset, datasets):
    """Check that every dataset yields a graph object."""
    # Use the module-level datasets directory directly; creating a
    # temporary download directory first would be immediately overridden.
    # NOTE(review): presumably files already present in dataset_path are
    # reused instead of being downloaded - confirm against get_edgelist().
    datasets.set_download_dir(dataset_path)
    G = dataset.get_graph(fetch=True)

    assert G is not None


@pytest.mark.parametrize("dataset", ALL_DATASETS)
def test_metadata(dataset):
Expand All @@ -167,3 +166,27 @@ def test_get_path(dataset, datasets):
def test_get_path_raises(dataset):
with pytest.raises(RuntimeError):
dataset.get_path()


@pytest.mark.parametrize("dataset", ALL_DATASETS_WGT)
def test_weights(dataset, datasets):
    """Check that ignore_weights controls whether the graph is weighted."""
    datasets.set_download_dir(dataset_path)

    # Build both variants first, then compare them.
    weighted_graph = dataset.get_graph(fetch=True)
    unweighted_graph = dataset.get_graph(fetch=True, ignore_weights=True)

    assert weighted_graph.is_weighted()
    assert not unweighted_graph.is_weighted()


@pytest.mark.parametrize("dataset", SMALL_DATASETS)
def test_create_using(dataset, datasets):
    """Check that create_using determines the directedness of the graph."""
    datasets.set_download_dir(dataset_path)

    # Default argument, the Graph class, and a directed Graph instance.
    default_graph = dataset.get_graph()
    from_class = dataset.get_graph(create_using=Graph)
    from_directed_instance = dataset.get_graph(
        create_using=Graph(directed=True))

    assert not default_graph.is_directed()
    assert not from_class.is_directed()
    assert from_directed_instance.is_directed()

0 comments on commit 2a57740

Please sign in to comment.