Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Triple_Link_Prediction #371

Merged
merged 12 commits into from
Aug 11, 2022
182 changes: 179 additions & 3 deletions cogdl/datasets/kg_data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,186 @@
import os.path as osp

import numpy as np

import torch
from cogdl.data import Graph, Dataset
from cogdl.utils import download_url


class BidirectionalOneShotIterator(object):
    """Endless iterator that alternates between two dataloaders.

    Every call to ``next`` increments an internal step counter; odd steps
    return a batch from ``dataloader_tail`` and even steps return a batch
    from ``dataloader_head``.  Each dataloader is restarted automatically
    once it is exhausted, so iteration normally never stops.

    Args:
        dataloader_head: Re-iterable producing the "head" batches.
        dataloader_tail: Re-iterable producing the "tail" batches.
    """

    def __init__(self, dataloader_head, dataloader_tail):
        self.iterator_head = self.one_shot_iterator(dataloader_head)
        self.iterator_tail = self.one_shot_iterator(dataloader_tail)
        self.step = 0

    def __iter__(self):
        # Required by the iterator protocol so the object works with
        # ``for`` loops, ``iter()`` and ``itertools`` helpers.
        return self

    def __next__(self):
        self.step += 1
        # The first call (step == 1) is odd, so the tail loader goes first.
        if self.step % 2 == 0:
            return next(self.iterator_head)
        return next(self.iterator_tail)

    @staticmethod
    def one_shot_iterator(dataloader):
        """
        Transform a PyTorch Dataloader into python iterator

        The returned generator cycles over ``dataloader`` indefinitely.  If a
        full pass yields nothing (empty dataloader) the generator ends instead
        of busy-looping forever.
        """
        while True:
            produced_any = False
            for data in dataloader:
                produced_any = True
                yield data
            if not produced_any:
                return


class TestDataset(torch.utils.data.Dataset):
    """Evaluation dataset that ranks a true triple against every entity.

    For each test triple, one candidate is produced per entity id in
    ``range(nentity)`` by replacing the head (``"head-batch"``) or the tail
    (``"tail-batch"``).  Candidates that form another known true triple are
    replaced by the original entity and tagged with a bias of ``-1``.
    """

    def __init__(self, triples, all_true_triples, nentity, nrelation, mode):
        self.len = len(triples)
        self.triple_set = set(all_true_triples)
        self.triples = triples
        self.nentity = nentity
        self.nrelation = nrelation
        self.mode = mode

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return ``(positive_sample, negative_sample, filter_bias, mode)``.

        Raises:
            ValueError: if ``self.mode`` is neither ``"head-batch"`` nor
                ``"tail-batch"``.
        """
        head, relation, tail = self.triples[idx]

        # Select which slot is corrupted and which entity is the true answer.
        if self.mode == "head-batch":
            true_entity = head

            def corrupt(candidate):
                return (candidate, relation, tail)

        elif self.mode == "tail-batch":
            true_entity = tail

            def corrupt(candidate):
                return (head, relation, candidate)

        else:
            raise ValueError("negative batch mode %s not supported" % self.mode)

        rows = []
        for candidate in range(self.nentity):
            if corrupt(candidate) in self.triple_set:
                # Known true triple: mask it out with bias -1 and fall back
                # to the true entity id.
                rows.append((-1, true_entity))
            else:
                rows.append((0, candidate))
        # The triple under test itself always stays unmasked.
        rows[true_entity] = (0, true_entity)

        table = torch.LongTensor(rows)
        filter_bias = table[:, 0].float()
        negative_sample = table[:, 1]
        positive_sample = torch.LongTensor((head, relation, tail))

        return positive_sample, negative_sample, filter_bias, self.mode

    @staticmethod
    def collate_fn(data):
        """Stack per-item tensors into a batch; mode is taken from item 0."""
        positives, negatives, biases = [], [], []
        for item in data:
            positives.append(item[0])
            negatives.append(item[1])
            biases.append(item[2])
        positive_sample = torch.stack(positives, dim=0)
        negative_sample = torch.stack(negatives, dim=0)
        filter_bias = torch.stack(biases, dim=0)
        mode = data[0][3]
        return positive_sample, negative_sample, filter_bias, mode


class TrainDataset(torch.utils.data.Dataset):
    """Training dataset that pairs each true triple with random negatives.

    Each item consists of the positive ``(head, relation, tail)`` triple,
    ``negative_sample_size`` corrupted entity ids (heads for ``"head-batch"``,
    tails for ``"tail-batch"``) that do not form a known training triple, and
    a frequency-based subsampling weight.

    Args:
        triples: Sequence of hashable ``(head, relation, tail)`` tuples.
        nentity: Number of entities; negatives are drawn from ``[0, nentity)``.
        nrelation: Number of relations (stored, not used by this class).
        negative_sample_size: Number of negatives returned per positive.
        mode: ``"head-batch"`` or ``"tail-batch"``.
    """

    def __init__(self, triples, nentity, nrelation, negative_sample_size, mode):
        self.len = len(triples)
        self.triples = triples
        self.triple_set = set(triples)
        self.nentity = nentity
        self.nrelation = nrelation
        self.negative_sample_size = negative_sample_size
        self.mode = mode
        self.count = self.count_frequency(triples)
        self.true_head, self.true_tail = self.get_true_head_and_tail(self.triples)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return ``(positive_sample, negative_sample, subsampling_weight, mode)``.

        Raises:
            ValueError: if ``self.mode`` is neither ``"head-batch"`` nor
                ``"tail-batch"``.
        """
        positive_sample = self.triples[idx]
        head, relation, tail = positive_sample

        # Weight is 1 / sqrt(freq(head, relation) + freq(tail, inverse relation)).
        subsampling_weight = self.count[(head, relation)] + self.count[(tail, -relation - 1)]
        subsampling_weight = torch.sqrt(1 / torch.Tensor([subsampling_weight]))

        # The entities to exclude do not change between resampling rounds,
        # so resolve them (and validate the mode) once, outside the loop.
        if self.mode == "head-batch":
            true_entities = self.true_head[(relation, tail)]
        elif self.mode == "tail-batch":
            true_entities = self.true_tail[(head, relation)]
        else:
            raise ValueError("Training batch mode %s not supported" % self.mode)

        negative_sample_list = []
        negative_sample_size = 0

        # Oversample by 2x per round and keep drawing until enough candidates
        # survive the filter.
        # NOTE(review): this never terminates if every entity id is a true
        # entity for this (relation, tail) / (head, relation) pair.
        while negative_sample_size < self.negative_sample_size:
            negative_sample = np.random.randint(self.nentity, size=self.negative_sample_size * 2)
            # Random draws contain duplicates, so ``assume_unique`` must not
            # be set: it can wrongly discard valid candidates.
            mask = np.isin(negative_sample, true_entities, invert=True)
            negative_sample = negative_sample[mask]
            negative_sample_list.append(negative_sample)
            negative_sample_size += negative_sample.size

        negative_sample = np.concatenate(negative_sample_list)[: self.negative_sample_size]

        negative_sample = torch.LongTensor(negative_sample)

        positive_sample = torch.LongTensor(positive_sample)

        return positive_sample, negative_sample, subsampling_weight, self.mode

    @staticmethod
    def collate_fn(data):
        """Batch items: stack samples, concatenate weights, mode from item 0."""
        positive_sample = torch.stack([_[0] for _ in data], dim=0)
        negative_sample = torch.stack([_[1] for _ in data], dim=0)
        subsample_weight = torch.cat([_[2] for _ in data], dim=0)
        mode = data[0][3]
        return positive_sample, negative_sample, subsample_weight, mode

    @staticmethod
    def count_frequency(triples, start=4):
        """
        Get frequency of a partial triple like (head, relation) or (relation, tail)
        The frequency will be used for subsampling like word2vec

        The tail side is keyed as ``(tail, -relation - 1)``.  A key seen for
        the first time is initialised to ``start``; every further occurrence
        adds one.
        """
        count = {}
        for head, relation, tail in triples:
            for key in ((head, relation), (tail, -relation - 1)):
                count[key] = count[key] + 1 if key in count else start
        return count

    @staticmethod
    def get_true_head_and_tail(triples):
        """
        Build a dictionary of true triples that will
        be used to filter these true triples for negative sampling

        Returns:
            ``(true_head, true_tail)`` where ``true_head[(relation, tail)]``
            and ``true_tail[(head, relation)]`` are NumPy arrays of the
            distinct entity ids completing a true triple.
        """
        true_head = {}
        true_tail = {}

        for head, relation, tail in triples:
            true_tail.setdefault((head, relation), []).append(tail)
            true_head.setdefault((relation, tail), []).append(head)

        # Deduplicate and convert to arrays (values replaced in place; the
        # key set is unchanged, so iterating ``items()`` is safe).
        for key, heads in true_head.items():
            true_head[key] = np.array(list(set(heads)))
        for key, tails in true_tail.items():
            true_tail[key] = np.array(list(set(tails)))

        return true_head, true_tail




def read_triplet_data(folder):
filenames = ["train2id.txt", "valid2id.txt", "test2id.txt"]
count = 0
Expand All @@ -18,7 +194,6 @@ def read_triplet_data(folder):
relation_dic = {}
for filename in filenames:
with open(osp.join(folder, filename), "r") as f:
_ = int(f.readline().strip())
if "train" in filename:
train_start_idx = len(triples)
elif "valid" in filename:
Expand All @@ -27,6 +202,8 @@ def read_triplet_data(folder):
test_start_idx = len(triples)
for line in f:
items = line.strip().split()
if len(items) != 3:
continue
edge_index.append([int(items[0]), int(items[1])])
edge_attr.append(int(items[2]))
triples.append((int(items[0]), int(items[2]), int(items[1])))
Expand Down Expand Up @@ -110,9 +287,8 @@ def get(self, idx):

def download(self):
    """Fetch every file listed in ``self.raw_file_names`` into ``self.raw_dir``.

    The URL of each file is built by formatting the ``self.url`` template
    with the dataset name and the raw file name.
    """
    for raw_name in self.raw_file_names:
        file_url = self.url.format(self.name, raw_name)
        download_url(file_url, self.raw_dir, name=raw_name)

def process(self):
(
data,
Expand Down
4 changes: 3 additions & 1 deletion cogdl/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,12 @@ def train(args): # noqa: C901
logger=args.logger,
log_path=args.log_path,
project=args.project,
no_test=args.no_test,
return_model=args.return_model,
nstage=args.nstage,
actnn=args.actnn,
fp16=args.fp16,
do_test=args.do_test,
do_valid=args.do_valid,
)

# Go!!!
Expand Down
4 changes: 4 additions & 0 deletions cogdl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def build_model(args):


SUPPORTED_MODELS = {
"transe":"cogdl.models.emb.transe.TransE",
"complex":"cogdl.models.emb.complex.ComplEx",
"distmult":"cogdl.models.emb.distmult.DistMult",
"rotate":"cogdl.models.emb.rotate.RotatE",
"hope": "cogdl.models.emb.hope.HOPE",
"spectral": "cogdl.models.emb.spectral.Spectral",
"hin2vec": "cogdl.models.emb.hin2vec.Hin2vec",
Expand Down
37 changes: 37 additions & 0 deletions cogdl/models/emb/complex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

from .. import BaseModel
from .knowledge_base import KGEModel


class ComplEx(KGEModel):
    r"""
    ComplEx knowledge-graph embedding model, from the paper
    `"Complex Embeddings for Simple Link Prediction"<http://proceedings.mlr.press/v48/trouillon16.pdf>`.
    Implementation borrowed from
    `KnowledgeGraphEmbedding<https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding>`.
    """

    @staticmethod
    def add_args(parser):
        """Register the ComplEx command-line options on ``parser``."""
        parser.add_argument("--embedding_size", type=int, default=500, help="Dimensionality of embedded vectors")
        parser.add_argument("--gamma", type=float, default=12.0, help="Hyperparameter for embedding")
        parser.add_argument("--double_entity_embedding", default=True)
        parser.add_argument("--double_relation_embedding", default=True)

    def score(self, head, relation, tail, mode):
        """Score triples with the ComplEx trilinear product.

        Each of ``head``, ``relation`` and ``tail`` is split along dim 2 into
        a real half and an imaginary half.  For ``mode == "head-batch"`` the
        relation is first combined with the tail and then with the head; for
        any other mode the head is first combined with the relation and then
        with the tail.  The element-wise result is summed over dim 2.
        """
        head_re, head_im = torch.chunk(head, 2, dim=2)
        rel_re, rel_im = torch.chunk(relation, 2, dim=2)
        tail_re, tail_im = torch.chunk(tail, 2, dim=2)

        if mode == "head-batch":
            real_part = rel_re * tail_re + rel_im * tail_im
            imag_part = rel_re * tail_im - rel_im * tail_re
            product = head_re * real_part + head_im * imag_part
        else:
            real_part = head_re * rel_re - head_im * rel_im
            imag_part = head_re * rel_im + head_im * rel_re
            product = real_part * tail_re + imag_part * tail_im

        return product.sum(dim=2)
25 changes: 25 additions & 0 deletions cogdl/models/emb/distmult.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from .. import BaseModel
from .knowledge_base import KGEModel


class DistMult(KGEModel):
    r"""DistMult knowledge-graph embedding model, from the ICLR 2015 paper
    `"EMBEDDING ENTITIES AND RELATIONS FOR LEARNING AND INFERENCE IN KNOWLEDGE BASES"
    <https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/ICLR2015_updated.pdf>`.
    Implementation borrowed from
    `KnowledgeGraphEmbedding<https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding>`.
    """

    def __init__(
        self, nentity, nrelation, hidden_dim, gamma, double_entity_embedding=False, double_relation_embedding=False
    ):
        # Pure pass-through to the base class; only the defaults of the two
        # ``double_*`` flags are fixed here.
        super().__init__(nentity, nrelation, hidden_dim, gamma, double_entity_embedding, double_relation_embedding)

    def score(self, head, relation, tail, mode):
        """Return the element-wise product of head, relation and tail summed over dim 2.

        ``mode`` only selects the grouping of the multiplication:
        ``"head-batch"`` multiplies relation and tail first, any other mode
        multiplies head and relation first.
        """
        if mode == "head-batch":
            triple_product = head * (relation * tail)
        else:
            triple_product = (head * relation) * tail
        return triple_product.sum(dim=2)
Loading