Skip to content

Commit

Permalink
[Dataset] update ogbl datasets (#358)
Browse files Browse the repository at this point in the history
* update ogbl datasets

* update ogbl dataset

* update ogbl test

* fix bugs for ogbl tests

* update init file for ogb datasets

* update init file

* fix comma bug

* fix bugs

* update for ogbl

Co-authored-by: xinjie zhang <[email protected]>
  • Loading branch information
Diego0511 and xinjie zhang authored Aug 1, 2022
1 parent d5f268f commit 169c607
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 0 deletions.
4 changes: 4 additions & 0 deletions cogdl/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ def build_dataset_from_path(data_path, dataset=None):
"ogbg-molpcba": "cogdl.datasets.ogb.OGBMolpcbaDataset",
"ogbg-ppa": "cogdl.datasets.ogb.OGBPpaDataset",
"ogbg-code": "cogdl.datasets.ogb.OGBCodeDataset",
"ogbl-ppa": "cogdl.datasets.ogb.OGBLPpaDataset",
"ogbl-ddi": "cogdl.datasets.ogb.OGBLDdiDataset",
"ogbl-collab": "cogdl.datasets.ogb.OGBLCollabDataset",
"ogbl-citation2": "cogdl.datasets.ogb.OGBLCitation2Dataset",
"amazon": "cogdl.datasets.gatne.AmazonDataset",
"twitter": "cogdl.datasets.gatne.TwitterDataset",
"youtube": "cogdl.datasets.gatne.YouTubeDataset",
Expand Down
83 changes: 83 additions & 0 deletions cogdl/datasets/ogb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ogb.nodeproppred import NodePropPredDataset
from ogb.nodeproppred import Evaluator as NodeEvaluator
from ogb.graphproppred import GraphPropPredDataset
from ogb.linkproppred import LinkPropPredDataset

from cogdl.data import Dataset, Graph, DataLoader
from cogdl.utils import CrossEntropyLoss, Accuracy, remove_self_loops, coalesce, BCEWithLogitsLoss
Expand Down Expand Up @@ -234,3 +235,85 @@ class OGBCodeDataset(OGBGDataset):
def __init__(self, data_path="data"):
    """Forward ``data_path`` and the dataset name ``"ogbg-code"`` to the parent initialiser."""
    super(OGBCodeDataset, self).__init__(data_path, "ogbg-code")


# Datasets for OGB link property prediction (ogbl-*).

class OGBLDataset(Dataset):
    """Expose an OGB link-property-prediction dataset as a single cogdl ``Graph``.

    The graph is read through ``LinkPropPredDataset``; its edges are passed
    through ``remove_self_loops``, mirrored (every ``(u, v)`` also stored as
    ``(v, u)``) and run through ``coalesce`` before being stored in
    ``self.data``.
    """

    def __init__(self, root, name):
        """
        - name (str): name of the dataset
        - root (str): root directory to store the dataset folder
        """

        self.name = name

        dataset = LinkPropPredDataset(name, root)
        # Keep a handle on the OGB dataset: get_edge_split() reads
        # self.dataset, which previously was never assigned and therefore
        # raised AttributeError on every call.
        self.dataset = dataset

        graph = dataset[0]
        x = torch.tensor(graph["node_feat"]).contiguous() if graph["node_feat"] is not None else None
        row, col = graph["edge_index"][0], graph["edge_index"][1]
        row = torch.from_numpy(row)
        col = torch.from_numpy(col)
        edge_index = torch.stack([row, col], dim=0)
        edge_attr = torch.as_tensor(graph["edge_feat"]) if graph["edge_feat"] is not None else graph["edge_feat"]
        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)

        # Mirror every edge so both directions are present.
        row = torch.cat([edge_index[0], edge_index[1]])
        col = torch.cat([edge_index[1], edge_index[0]])

        row, col, _ = coalesce(row, col)
        edge_index = torch.stack([row, col], dim=0)

        # NOTE(review): edge_attr is not mirrored/coalesced together with
        # edge_index, so the two may be misaligned when the dataset actually
        # provides edge features -- confirm before relying on edge_attr.
        self.data = Graph(x=x, edge_index=edge_index, edge_attr=edge_attr, y=None)
        self.data.num_nodes = graph["num_nodes"]

    def get(self, idx):
        """Return the graph; only index 0 is accepted."""
        assert idx == 0
        return self.data

    def get_loss_fn(self):
        """Return a ``CrossEntropyLoss`` instance."""
        return CrossEntropyLoss()

    def get_evaluator(self):
        """Return an ``Accuracy`` instance."""
        return Accuracy()

    def _download(self):
        # Intentionally a no-op in this class.
        pass

    @property
    def processed_file_names(self):
        """Name of the processed data file."""
        return "data_cogdl.pt"

    def _process(self):
        # Intentionally a no-op in this class.
        pass

    def get_edge_split(self):
        """Return the train/valid/test edges as transposed tensors.

        Returns:
            tuple: ``(train_edge, val_edge, test_edge)``, each built from the
            transposed ``'edge'`` array of the corresponding OGB split.
        """
        # NOTE(review): this assumes every split exposes an 'edge' array;
        # verify that this holds for all registered ogbl datasets.
        idx = self.dataset.get_edge_split()
        train_edge = torch.from_numpy(idx['train']['edge'].T)
        val_edge = torch.from_numpy(idx['valid']['edge'].T)
        test_edge = torch.from_numpy(idx['test']['edge'].T)
        return train_edge, val_edge, test_edge

class OGBLPpaDataset(OGBLDataset):
    """``OGBLDataset`` bound to the ``ogbl-ppa`` dataset name."""

    def __init__(self, data_path="data"):
        super().__init__(data_path, "ogbl-ppa")


class OGBLCollabDataset(OGBLDataset):
    """``OGBLDataset`` bound to the ``ogbl-collab`` dataset name."""

    def __init__(self, data_path="data"):
        super().__init__(data_path, "ogbl-collab")


class OGBLDdiDataset(OGBLDataset):
    """``OGBLDataset`` bound to the ``ogbl-ddi`` dataset name."""

    def __init__(self, data_path="data"):
        super().__init__(data_path, "ogbl-ddi")


class OGBLCitation2Dataset(OGBLDataset):
    """``OGBLDataset`` bound to the ``ogbl-citation2`` dataset name."""

    def __init__(self, data_path="data"):
        super().__init__(data_path, "ogbl-citation2")


15 changes: 15 additions & 0 deletions tests/datasets/test_ogb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,22 @@ def test_ogbg_molhiv():
assert dataset.all_nodes == 1049163
assert len(dataset.data) == 41127

def test_ogbl_ddi():
    """Build ``ogbl-ddi`` via ``build_dataset`` and check the reported node count."""
    name = "ogbl-ddi"
    args = build_args_from_dict({"dataset": name})
    assert args.dataset == name
    graph = build_dataset(args).data
    assert graph.num_nodes == 4267

def test_ogbl_collab():
    """Build ``ogbl-collab`` via ``build_dataset`` and check the reported node count."""
    name = "ogbl-collab"
    args = build_args_from_dict({"dataset": name})
    assert args.dataset == name
    graph = build_dataset(args).data
    assert graph.num_nodes == 235868

if __name__ == "__main__":
    # Run every dataset check in order when the file is executed as a script.
    for check in (test_ogbn_arxiv, test_ogbg_molhiv, test_ogbl_ddi, test_ogbl_collab):
        check()

0 comments on commit 169c607

Please sign in to comment.