import os
from os.path import join as opjoin

import dgl
import dgl.data.citation_graph as dglcitationgraph
import networkx as nx
import numpy as np
import scipy.sparse as sp
import torch
import torch_geometric as pyg
import torch_geometric.utils as pygutils


def index_to_mask(index, size):
    """Convert a 1-D tensor of node indices into a boolean mask of length `size`."""
    mask = torch.zeros(size, dtype=torch.bool, device=index.device)
    mask[index] = 1
    return mask


def resplit(dataset, data, full_sup, num_classes, num_nodes, num_per_class):
    """Re-generate random train/val/test masks for a dataset.

    Args:
        dataset: dataset name ('cora' | 'citeseer' | 'pubmed' | 'coauthorcs').
        data: list [graph, features, labels, train_mask, val_mask, test_mask];
            only positions 2 (labels, read) and 3-5 (masks, overwritten) are used,
            so callers may pass None for the unused slots.
        full_sup: if True, use a fully-supervised split (all non-val/test
            nodes are training nodes).
        num_classes: number of label classes.
        num_nodes: total number of nodes (mask length).
        num_per_class: labelled training nodes per class (semi-supervised).

    Returns:
        The same `data` list with masks at positions 3-5 replaced.
    """
    if dataset in ['cora', 'citeseer', 'pubmed']:
        if full_sup:
            # Fully supervised: random 500 test / 1000 val, everything else train.
            perm = torch.randperm(data[2].shape[0])
            test_index = perm[:500]
            val_index = perm[500:1500]
            train_index = perm[1500:]
            data[3] = index_to_mask(train_index, size=num_nodes)
            data[4] = index_to_mask(val_index, size=num_nodes)
            data[5] = index_to_mask(test_index, size=num_nodes)
        else:
            # Semi-supervised: num_per_class labelled nodes per class for
            # training, then 500 val / 1000 test drawn from the shuffled rest.
            indices = []
            for i in range(num_classes):
                index = (data[2].long() == i).nonzero().view(-1)
                index = index[torch.randperm(index.size(0))]
                indices.append(index)
            train_index = torch.cat([i[:num_per_class] for i in indices], dim=0)
            rest_index = torch.cat([i[num_per_class:] for i in indices], dim=0)
            rest_index = rest_index[torch.randperm(rest_index.size(0))]
            data[3] = index_to_mask(train_index, size=num_nodes)
            data[4] = index_to_mask(rest_index[:500], size=num_nodes)
            data[5] = index_to_mask(rest_index[500:1500], size=num_nodes)
    elif dataset in ['coauthorcs']:
        if full_sup:
            raise NotImplementedError
        else:
            train_index = []
            val_index = []
            test_index = []
            for i in range(num_classes):
                index = (data[2].long() == i).nonzero().view(-1)
                index = index[torch.randperm(index.size(0))]
                # Skip classes too small to provide num_per_class train
                # nodes plus 30 validation nodes.
                if len(index) > num_per_class + 30:
                    train_index.append(index[:num_per_class])
                    val_index.append(index[num_per_class:num_per_class + 30])
                    test_index.append(index[num_per_class + 30:])
                else:
                    continue
            train_index = torch.cat(train_index)
            val_index = torch.cat(val_index)
            test_index = torch.cat(test_index)
            data[3] = index_to_mask(train_index, size=num_nodes)
            data[4] = index_to_mask(val_index, size=num_nodes)
            data[5] = index_to_mask(test_index, size=num_nodes)
    return data


def load_data(config):
    """Load a graph dataset through either the DGL or PyG backend.

    Returns a list [graph_skeleton (dgl.DGLGraph), features, labels,
    idx_train, idx_val, idx_test] regardless of backend.

    Raises:
        NotImplementedError: unknown dataset or backend name.
    """
    if not config['data']['dataset'] in config['data']['all_datasets']:
        raise NotImplementedError
    else:
        if config['data']['implement'] == 'dgl':
            graph = get_dgl_dataset(dataroot=config['path']['dataroot'],
                                    dataset=config['data']['dataset'])
            graph = rewarp_graph(graph, config)
            features = torch.tensor(graph.features, dtype=torch.float)
            if config['data']['feature_prenorm']:
                features = row_normalization(features)
            labels = torch.tensor(graph.labels, dtype=torch.long)
            idx_train = torch.tensor(graph.train_mask, dtype=torch.bool)
            idx_val = torch.tensor(graph.val_mask, dtype=torch.bool)
            idx_test = torch.tensor(graph.test_mask, dtype=torch.bool)
            g_nx = graph.graph
            if config['data']['add_slflp']:
                # Drop any existing self-loops, then add exactly one per node.
                g_nx.remove_edges_from(nx.selfloop_edges(g_nx))
                # FIX: networkx graphs have no `add_edges` method (that is the
                # DGLGraph API) — use `add_edges_from`.
                g_nx.add_edges_from(zip(g_nx.nodes(), g_nx.nodes()))
            graph_skeleton = dgl.DGLGraph(g_nx)
            return [graph_skeleton, features, labels, idx_train, idx_val, idx_test]
        elif config['data']['implement'] == 'pyg':
            graph = get_pyg_dataset(dataroot=config['path']['dataroot'],
                                    dataset=config['data']['dataset'])
            graph = rewarp_graph(graph, config)
            data = graph.data
            idx_train = data.train_mask
            idx_val = data.val_mask
            idx_test = data.test_mask
            labels = data.y
            num_nodes = data.num_nodes
            features = data.x
            if config['data']['feature_prenorm']:
                features = row_normalization(features)
            edge_index = data.edge_index
            if config['data']['add_slflp']:
                edge_index = pygutils.add_self_loops(edge_index)[0]
            # Rebuild a DGL skeleton from the PyG edge list.
            graph_skeleton = dgl.DGLGraph()
            graph_skeleton.add_nodes(num_nodes)
            graph_skeleton.add_edges(edge_index[0, :], edge_index[1, :])
            return [graph_skeleton, features, labels, idx_train, idx_val, idx_test]
        else:
            raise NotImplementedError


def dummy_normalization(mx):
    """No-op stand-in for DGL's internal normalization.

    Dense arrays and CSR matrices pass through unchanged; LIL matrices are
    densified. Used to monkey-patch DGL so features/adjacency are NOT
    pre-normalized by the library.
    """
    if isinstance(mx, np.ndarray) or isinstance(mx, sp.csr.csr_matrix):
        pass
    elif isinstance(mx, sp.lil.lil_matrix):
        mx = np.asarray(mx.todense())
    else:
        raise NotImplementedError
    return mx


def get_dgl_dataset(dataroot, dataset):
    """Fetch a DGL dataset object, disabling DGL's built-in normalization."""
    # Monkey-patch DGL's citation-graph loaders so they do not normalize
    # features/adjacency behind our back.
    dglcitationgraph._normalize = dummy_normalization
    dglcitationgraph._preprocess_features = dummy_normalization
    if dataset == 'cora':
        graph = dgl.data.CoraDataset()
    elif dataset in ['citeseer', 'pubmed']:
        graph = dgl.data.CitationGraphDataset(name=dataset)
    elif dataset == 'coauthorcs':
        # HACK: temporarily force np.load(allow_pickle=True) because the
        # Coauthor loader reads pickled .npz files (numpy >= 1.16.3 defaults
        # to allow_pickle=False); restore the default immediately after.
        np.load.__defaults__ = (None, True, True, 'ASCII')
        graph = dgl.data.Coauthor(name='cs')
        np.load.__defaults__ = (None, False, True, 'ASCII')
    else:
        raise NotImplementedError
    return graph


def get_pyg_dataset(dataroot, dataset):
    """Fetch a torch_geometric dataset object rooted under `dataroot`."""
    if dataset in ['cora', 'citeseer', 'pubmed']:
        graph = pyg.datasets.Planetoid(root=opjoin(dataroot, dataset),
                                       name=dataset.capitalize())
    elif dataset == 'coauthorcs':
        graph = pyg.datasets.Coauthor(root=opjoin(dataroot, dataset), name='CS')
    else:
        raise NotImplementedError
    return graph


def normalize(mx):
    """Row-normalize sparse matrix (rows with zero sum are left as zeros)."""
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx


def accuracy(output, labels):
    """Return (mean accuracy, per-sample correctness tensor) for logits vs labels."""
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double()
    correct_sum = correct.sum()
    return correct_sum / len(labels), correct


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)


# Actually Sacred will automatically save source code files added by its
# add_source_file method (please see line 28 in config.py).
# But it adds a md5 hash string after the file name, and saves the source codes
# for every run in the same file. So here we manually call its save_file method
# to save the source codes of given name in the location specified by us,
# and finally delete the source codes saved by Sacred.
def save_source(run):
    """Copy a Sacred run's source files into <observer dir>/source/."""
    if run.observers:
        for source_file, _ in run.experiment_info['sources']:
            os.makedirs(os.path.dirname('{0}/source/{1}'.format(
                run.observers[0].dir, source_file)), exist_ok=True)
            run.observers[0].save_file(source_file,
                                       'source/{0}'.format(source_file))
        sacred_source_path = f'{run.observers[0].basedir}/_sources'
        # if os.path.exists(sacred_source_path):
        #     shutil.rmtree(sacred_source_path)


def adjust_learning_rate(optimizer, epoch, lr_down_epoch_list, logger):
    """Multiply the optimizer's learning rate by 0.1 at the listed epochs."""
    if epoch != 0 and epoch in lr_down_epoch_list:
        # The original `list(dict(optimizer=optimizer).keys())[0]` always
        # evaluates to the literal 'optimizer' — use it directly.
        logger.info('update learning rate of optimizer')
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.1
            logger.info(param_group['lr'])


def check_before_pkl(data):
    """Recursively assert that `data` contains no torch.Tensor.

    Guards against pickling tensors (which drags device/storage state into
    the pickle); containers are walked, leaves are checked.
    """
    if isinstance(data, (list, tuple)):
        for each in data:
            check_before_pkl(each)
    elif isinstance(data, dict):
        for value in data.values():
            check_before_pkl(value)
    else:
        assert not isinstance(data, torch.Tensor)


def row_normalization(features):
    """Normalize the feature matrix by its row sums (zero rows stay zero)."""
    rowsum = features.sum(dim=1)
    inv_rowsum = torch.pow(rowsum, -1)
    inv_rowsum[torch.isinf(inv_rowsum)] = 0.
    features = features * inv_rowsum[..., None]
    return features


def rewarp_graph(graph, config):
    """Attach random train/val/test masks to datasets that ship without them.

    Coauthor/CoraFull datasets (DGL or PyG flavours) have no predefined
    split, so one is generated via resplit(); everything else passes
    through untouched.
    """
    # The CoraFull class moved between PyG releases.
    if pyg.__version__ in ['1.4.2', '1.3.2']:
        pyg_corafull_type = pyg.datasets.cora_full.CoraFull
    else:
        pyg_corafull_type = pyg.datasets.citation_full.CoraFull  # pylint: disable=no-member
    # NOTE: 'gnn_benckmark' (sic) is the actual module name in DGL.
    if isinstance(graph, dgl.data.gnn_benckmark.Coauthor) or \
            isinstance(graph, dgl.data.gnn_benckmark.CoraFull):
        graph = PseudoDGLGraph(graph)
        pseudo_data = [None, None, graph.labels, None, None, None]
        _, _, _, train_mask, val_mask, test_mask = resplit(
            dataset=config['data']['dataset'],
            data=pseudo_data,
            full_sup=config['data']['full_sup'],
            num_classes=graph.num_classes,
            num_nodes=graph.num_nodes,
            num_per_class=config['data']['label_per_class'])
        graph.train_mask = train_mask
        graph.val_mask = val_mask
        graph.test_mask = test_mask
    elif isinstance(graph, pyg_corafull_type) or \
            isinstance(graph, pyg.datasets.coauthor.Coauthor):
        pseudo_data = [None, None, graph.data.y, None, None, None]
        _, _, _, train_mask, val_mask, test_mask = resplit(
            dataset=config['data']['dataset'],
            data=pseudo_data,
            full_sup=config['data']['full_sup'],
            num_classes=torch.unique(graph.data.y).shape[0],
            num_nodes=graph.data.num_nodes,
            num_per_class=config['data']['label_per_class'])
        graph.data.train_mask = train_mask
        graph.data.val_mask = val_mask
        graph.data.test_mask = test_mask
    else:
        pass
    return graph


class PseudoDGLGraph():
    """Adapter exposing a DGL benchmark dataset through the citation-graph
    attribute interface (graph/features/labels/masks) that load_data expects."""

    def __init__(self, graph):
        self.graph = graph.data[0].to_networkx()
        self.features = graph.data[0].ndata['feat']
        self.labels = graph.data[0].ndata['label']
        self.num_classes = torch.unique(self.labels).shape[0]
        self.num_nodes = graph.data[0].number_of_nodes()
        # Masks are filled in later by rewarp_graph()/resplit().
        self.train_mask = None
        self.val_mask = None
        self.test_mask = None