-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
104 lines (86 loc) · 3.29 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import json
import os
from functools import namedtuple
import scipy.sparse
from sklearn.preprocessing import StandardScaler
import dgl
import numpy as np
import torch
from sklearn.metrics import f1_score
class Logger:
    """Append messages to a log file and echo them to stdout."""

    def __init__(self, path):
        """Bind the logger to a log file.

        Parameters
        ----------
        path : str
            Path of the file that messages are appended to. The file is only
            opened (in append mode) inside ``write``, once per call.
        """
        self.path = path

    def write(self, s):
        """Append ``str(s)`` to the log file, then ``print(s)``.

        No newline is added to the file itself; ``print`` adds one on
        stdout. Returns None.

        NOTE(review): ``print`` writes to ``sys.stdout``; if an instance of
        this class were installed as ``sys.stdout`` this call would recurse.
        """
        with open(self.path, mode='a') as log_file:
            log_file.write(str(s))
        print(s)
def save_log_dir(args):
    """Create (if needed) and return the run's log directory.

    The path is ``./log/<args.dataset>/<args.log_dir>``, relative to the
    current working directory. An already-existing directory is not an error.
    """
    target_dir = f'./log/{args.dataset}/{args.log_dir}'
    os.makedirs(target_dir, exist_ok=True)
    return target_dir
def calc_f1(y_true, y_pred, multilabel):
    """Compute micro- and macro-averaged F1 scores.

    Parameters
    ----------
    y_true : array-like
        Ground-truth targets, passed straight to ``f1_score``.
    y_pred : numpy.ndarray
        Raw model outputs. If ``multilabel`` is True, every entry > 0 is
        treated as a positive prediction (1) and everything else as 0.
        Otherwise the predicted class is the argmax along axis 1.
        The input array is NOT modified.
    multilabel : bool
        Selects the thresholding (True) or argmax (False) decoding above.

    Returns
    -------
    tuple
        ``(micro_f1, macro_f1)``.
    """
    if multilabel:
        y_pred = np.asarray(y_pred)
        # Threshold into a new array instead of writing in place, so the
        # caller's array (and any CPU tensor sharing its memory through
        # ``Tensor.numpy()``) is left untouched. The dtype is kept so the
        # values passed to f1_score are the same 0/1 entries as before.
        y_pred = (y_pred > 0).astype(y_pred.dtype)
    else:
        y_pred = np.argmax(y_pred, axis=1)
    return (f1_score(y_true, y_pred, average="micro"),
            f1_score(y_true, y_pred, average="macro"))
def evaluate(model, g, labels, mask, multilabel=False):
    """Score ``model`` on the nodes selected by ``mask``.

    The model is switched to eval mode (and is not switched back here). A
    forward pass on ``g`` runs under ``torch.no_grad()``. The masked outputs
    and masked ``labels`` are moved to CPU numpy arrays and scored with
    ``calc_f1``.

    Returns
    -------
    tuple
        ``(micro_f1, macro_f1)`` as computed by ``calc_f1``.
    """
    model.eval()
    with torch.no_grad():
        masked_logits = model(g)[mask]
        masked_labels = labels[mask]
        micro, macro = calc_f1(masked_labels.cpu().numpy(),
                               masked_logits.cpu().numpy(),
                               multilabel)
    return micro, macro
def _read_json(path):
    """Parse the UTF-8 JSON file at ``path``, closing the handle afterwards."""
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def _index_mask(num_nodes, ids):
    """Return a length-``num_nodes`` bool array that is True exactly at ``ids``."""
    mask = np.zeros((num_nodes,), dtype=bool)
    mask[ids] = True
    return mask


def load_data(args, multilabel):
    """Load a GraphSAINT-format dataset and convert it to a DGL graph.

    Files are read from ``./data/<args.dataset>/``: ``adj_full.npz``,
    ``adj_train.npz``, ``role.json`` (keys ``'tr'``, ``'va'``, ``'te'``),
    ``feats.npy`` and ``class_map.json``.

    Parameters
    ----------
    args : object
        Must provide ``args.dataset``, the dataset sub-directory name.
    multilabel : bool
        If True, each ``class_map`` value is a per-class target vector and
        labels are stored as float. Otherwise each value is a single class
        id and labels are stored as long.

    Returns
    -------
    Dataset (namedtuple)
        ``num_classes`` : number of classes.
        ``train_nid`` : sorted unique row indices of the nonzero entries of
        ``adj_train``. These nodes are also used to fit the feature scaler.
        ``g`` : graph built from ``adj_full``, with ndata ``'feat'``,
        ``'label'``, ``'train_mask'``, ``'val_mask'`` and ``'test_mask'``.
    """
    # ``namedtuple`` lives in ``collections``. The file-level import from
    # ``functools`` only works because functools happens to import it.
    from collections import namedtuple

    prefix = "data/{}".format(args.dataset)
    DataType = namedtuple('Dataset', ['num_classes', 'train_nid', 'g'])

    # Builtin ``bool`` as the dtype: the ``np.bool`` alias was removed in
    # NumPy 1.24 and raises AttributeError there.
    adj_full = scipy.sparse.load_npz('./{}/adj_full.npz'.format(prefix)).astype(bool)
    g = dgl.from_scipy(adj_full)
    num_nodes = g.num_nodes()

    adj_train = scipy.sparse.load_npz('./{}/adj_train.npz'.format(prefix)).astype(bool)
    # Unique source-row ids of the training adjacency, in sorted order.
    train_nid = np.unique(adj_train.nonzero()[0])

    role = _read_json('./{}/role.json'.format(prefix))
    train_mask = _index_mask(num_nodes, role['tr'])
    val_mask = _index_mask(num_nodes, role['va'])
    test_mask = _index_mask(num_nodes, role['te'])

    # Standardize features with statistics taken from the train_nid rows only.
    feats = np.load('./{}/feats.npy'.format(prefix))
    scaler = StandardScaler()
    scaler.fit(feats[train_nid])
    feats = scaler.transform(feats)

    class_map = {int(k): v for k, v in
                 _read_json('./{}/class_map.json'.format(prefix)).items()}
    if multilabel:
        # Multi-label binary classification: one target vector per node.
        num_classes = len(list(class_map.values())[0])
        class_arr = np.zeros((num_nodes, num_classes))
    else:
        num_classes = max(class_map.values()) - min(class_map.values()) + 1
        class_arr = np.zeros((num_nodes,))
    # Nodes absent from class_map keep an all-zero label.
    for node_id, target in class_map.items():
        class_arr[node_id] = target

    g.ndata['feat'] = torch.tensor(feats, dtype=torch.float)
    g.ndata['label'] = torch.tensor(class_arr, dtype=torch.float if multilabel else torch.long)
    g.ndata['train_mask'] = torch.tensor(train_mask, dtype=torch.bool)
    g.ndata['val_mask'] = torch.tensor(val_mask, dtype=torch.bool)
    g.ndata['test_mask'] = torch.tensor(test_mask, dtype=torch.bool)
    return DataType(g=g, num_classes=num_classes, train_nid=train_nid)