diff --git a/cogdl/data/sampler.py b/cogdl/data/sampler.py index 21ffb59c..8218552b 100644 --- a/cogdl/data/sampler.py +++ b/cogdl/data/sampler.py @@ -160,12 +160,15 @@ def __getitem__(self, idx): """ batch = self.node_idx[idx * self.batch_size : (idx + 1) * self.batch_size] self.random_walker.build_up(self.edge_index, self.total_num_nodes) - walk_res=self.random_walker.walk_one(batch,length=1,p=0.0) - + walk_res = self.random_walker.walk( + batch, walk_length=2, parallel=False + )[:,1] + neg_batch = torch.randint(0, self.total_num_nodes, (batch.numel(), ), dtype=torch.int64) pos_batch=torch.tensor(walk_res) - batch = torch.cat([batch, pos_batch, neg_batch], dim=0) + if self.sizes != [-1]: + batch = torch.cat([batch, pos_batch, neg_batch], dim=0) node_id = batch adj_list = [] for size in self.sizes: diff --git a/cogdl/models/__init__.py b/cogdl/models/__init__.py index c13d7c6a..4eb0cd75 100644 --- a/cogdl/models/__init__.py +++ b/cogdl/models/__init__.py @@ -104,7 +104,7 @@ def build_model(args): "sortpool": "cogdl.models.nn.sortpool.SortPool", "srgcn": "cogdl.models.nn.srgcn.SRGCN", "gcc": "cogdl.models.nn.gcc_model.GCCModel", - "unsup_graphsage": "cogdl.models.nn.graphsage.Graphsage", + "unsup_graphsage": "cogdl.models.nn.graphsage.UnsupGraphsage", "graphsaint": "cogdl.models.nn.graphsaint.GraphSAINT", "m3s": "cogdl.models.nn.m3s.M3S", "moe_gcn": "cogdl.models.nn.moe_gcn.MoEGCN", diff --git a/cogdl/models/nn/graphsage.py b/cogdl/models/nn/graphsage.py index afa96be8..b2a507d9 100644 --- a/cogdl/models/nn/graphsage.py +++ b/cogdl/models/nn/graphsage.py @@ -189,3 +189,19 @@ def forward(self, graph): for layer in self.layers: x = layer(graph, x) return x + +class UnsupGraphsage(Graphsage): + def __init__(self, num_features, num_classes, hidden_size, num_layers, sample_size, dropout, aggr): + super(Graphsage, self).__init__() + assert num_layers == len(sample_size) + self.adjlist = {} + self.num_features = num_features + self.num_classes = num_classes + self.hidden_size = hidden_size + self.num_layers = num_layers + self.sample_size = sample_size + self.dropout = dropout + shapes = [num_features] + hidden_size * num_layers + self.convs = nn.ModuleList( + [SAGELayer(shapes[layer], shapes[layer + 1], aggr=aggr) for layer in range(num_layers)] + ) \ No newline at end of file diff --git a/cogdl/utils/sampling.py b/cogdl/utils/sampling.py index 06567b45..806f1adb 100644 --- a/cogdl/utils/sampling.py +++ b/cogdl/utils/sampling.py @@ -110,25 +110,3 @@ def walk(self, start, walk_length, restart_p=0.0, parallel=True): result = random_walk_single(start, walk_length, self.indptr, self.indices, restart_p) result = np.array(result, dtype=np.int64) return result - - def walk_one(self, start, length, p): - walk_res = [np.zeros(length, dtype=np.int32)] * len(start) - p = 0.0 - for i in range(len(start)): - node = start[i] - result = [np.int32(0)] * length - index = np.int32(0) - _node = node - while index < length: - start1 = self.indptr[node] - end1 = self.indptr[node + 1] - sample1 = random.randint(start1, end1 - 1) - node = self.indices[sample1] - if np.random.uniform(0, 1) > p: - result[index] = node - else: - result[index] = _node - index += 1 - k = int(np.floor(np.random.rand() * len(result))) - walk_res[i] = result[k] - return walk_res diff --git a/cogdl/wrappers/model_wrapper/node_classification/unsup_graphsage_mw.py b/cogdl/wrappers/model_wrapper/node_classification/unsup_graphsage_mw.py index 3577ca2f..6fee7461 100644 --- a/cogdl/wrappers/model_wrapper/node_classification/unsup_graphsage_mw.py +++ b/cogdl/wrappers/model_wrapper/node_classification/unsup_graphsage_mw.py @@ -1,25 +1,28 @@ import torch import numpy as np -from cogdl.utils import RandomWalker -from cogdl.wrappers.tools.wrapper_utils import evaluate_node_embeddings_using_logreg +from cogdl.wrappers.tools.wrapper_utils import evaluate_node_embeddings_using_liblinear from .. import UnsupervisedModelWrapper - +from torch.nn import functional as F class UnsupGraphSAGEModelWrapper(UnsupervisedModelWrapper): @staticmethod def add_args(parser): # fmt: off + parser.add_argument("--num-shuffle", type=int, default=1) + parser.add_argument("--training-percents", default=[0.2], type=float, nargs="+") parser.add_argument("--walk-length", type=int, default=10) parser.add_argument("--negative-samples", type=int, default=30) # fmt: on - def __init__(self, model, optimizer_cfg, walk_length, negative_samples): + def __init__(self, model, optimizer_cfg, walk_length, negative_samples, num_shuffle=1, training_percents=[0.1]): super(UnsupGraphSAGEModelWrapper, self).__init__() self.model = model self.optimizer_cfg = optimizer_cfg self.walk_length = walk_length self.num_negative_samples = negative_samples + self.num_shuffle = num_shuffle + self.training_percents = training_percents def train_step(self, batch): @@ -27,26 +30,25 @@ def train_step(self, batch): out = self.model(x_src,adjs) out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0) - pos_loss = torch.log(torch.sigmoid((out * pos_out).sum(-1)).mean()) - neg_loss = torch.log(torch.sigmoid(-(out * neg_out).sum(-1)).mean()) + pos_loss = F.logsigmoid((out * pos_out).sum(-1)).mean() + neg_loss = F.logsigmoid(-(out * neg_out).sum(-1)).mean() loss = -pos_loss - neg_loss return loss - - def test_step(self, batch): - dataset, test_loader = batch + def test_step(self, graph): + dataset, test_loader = graph graph = dataset.data - if hasattr(self.model, "inference"): - pred = self.model.inference(graph.x, test_loader) + with torch.no_grad(): + if hasattr(self.model, "inference"): + pred = self.model.inference(graph.x, test_loader) + else: + pred = self.model(graph) + if len(graph.y.shape) > 1: + self.label_matrix = graph.y.numpy() else: - pred = self.model(graph) - pred= pred.split(pred.size(0) // 3, dim=0)[0] - pred = pred[graph.test_mask] - y = graph.y[graph.test_mask] - - metric = self.evaluate(pred, y, metric="auto") - self.note("test_loss", self.default_loss_fn(pred, y)) - self.note("test_metric", metric) + self.label_matrix = np.zeros((graph.num_nodes, graph.num_classes), dtype=int) + self.label_matrix[range(graph.num_nodes), graph.y.numpy()] = 1 + return evaluate_node_embeddings_using_liblinear(pred, self.label_matrix, self.num_shuffle, self.training_percents) def setup_optimizer(self): diff --git a/tests/models/ssl/test_contrastive_models.py b/tests/models/ssl/test_contrastive_models.py index f03c2cf2..12650250 100644 --- a/tests/models/ssl/test_contrastive_models.py +++ b/tests/models/ssl/test_contrastive_models.py @@ -49,7 +49,7 @@ def test_unsupervised_graphsage(): args.epochs = 2 args.checkpoint_path = "graphsage.pt" ret = train(args) - assert ret["test_acc"] > 0 + assert ret["micro-f1 0.1"] > 0 def test_dgi():