Skip to content

Commit

Permalink
[Model] update APIs for gensim 4.x (#361)
Browse files Browse the repository at this point in the history
* change parameter names for gensim.models.Word2Vec

* update requirements for gensim

* update other APIs for gensim 4.2.0

* update test_deepwalk

* Delete .vscode directory

* Update setup.py

* 4.2.0
  • Loading branch information
Saltsmart authored Aug 4, 2022
1 parent 169c607 commit f5f92d6
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 25 deletions.
4 changes: 2 additions & 2 deletions cogdl/models/emb/deepwalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ def forward(self, graph, embedding_model_creator=Word2Vec, return_dict=False):
print("training word2vec...")
model = embedding_model_creator(
walks,
size=self.dimension,
vector_size=self.dimension,
window=self.window_size,
min_count=0,
sg=1,
workers=self.worker,
iter=self.iteration,
epochs=self.iteration,
)
id2node = dict([(vid, node) for vid, node in enumerate(nx_g.nodes())])
embeddings = np.asarray([model.wv[str(id2node[i])] for i in range(len(id2node))])
Expand Down
8 changes: 4 additions & 4 deletions cogdl/models/emb/dgk.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,17 @@ def forward(self, graphs, **kwargs):

model = Word2Vec(
self.gl_collections,
size=self.hidden_dim,
vector_size=self.hidden_dim,
window=self.window,
min_count=self.min_count,
sample=self.sampling_rate,
workers=self.n_workers,
iter=self.epochs,
epochs=self.epochs,
alpha=self.alpha,
)
vectors = np.asarray([model.wv[str(node)] for node in model.wv.index2word])
vectors = np.asarray([model.wv[str(node)] for node in model.wv.index_to_key])
S = vectors.dot(vectors.T)
node2id = dict(zip(model.wv.index2word, range(len(model.wv.index2word))))
node2id = dict(zip(model.wv.index_to_key, range(len(model.wv.index_to_key))))

num_graph, size_vocab = len(graphs), len(node2id)
norm_prob = np.zeros((num_graph, size_vocab))
Expand Down
21 changes: 11 additions & 10 deletions cogdl/models/emb/gatne.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
import networkx as nx
from collections import defaultdict
from gensim.models.keyedvectors import Vocab
from gensim.models.keyedvectors import Vocab # Retained for now to ease the loading of older models.
# See: https://radimrehurek.com/gensim/models/keyedvectors.html?highlight=vocab#gensim.models.keyedvectors.CompatVocab
import random
import math
import tqdm
Expand Down Expand Up @@ -110,12 +111,12 @@ def __init__(
def forward(self, network_data):
device = "cpu" if not torch.cuda.is_available() else "cuda"
all_walks = generate_walks(network_data, self.walk_num, self.walk_length, schema=self.schema)
vocab, index2word = generate_vocab(all_walks)
vocab, index_to_key = generate_vocab(all_walks)
train_pairs = generate_pairs(all_walks, vocab)

edge_types = list(network_data.keys())

num_nodes = len(index2word)
num_nodes = len(index_to_key)
edge_type_count = len(edge_types)

epochs = self.epochs
Expand Down Expand Up @@ -189,7 +190,7 @@ def forward(self, network_data):
node_neigh = torch.tensor([neighbors[i] for _ in range(edge_type_count)]).to(device)
node_emb = model(train_inputs, train_types, node_neigh)
for j in range(edge_type_count):
final_model[edge_types[j]][index2word[i]] = node_emb[j].cpu().detach().numpy()
final_model[edge_types[j]][index_to_key[i]] = node_emb[j].cpu().detach().numpy()
return final_model


Expand Down Expand Up @@ -349,7 +350,7 @@ def generate_pairs(all_walks, vocab, window_size=5):


def generate_vocab(all_walks):
index2word = []
index_to_key = []
raw_vocab = defaultdict(int)

for walks in all_walks:
Expand All @@ -359,14 +360,14 @@ def generate_vocab(all_walks):

vocab = {}
for word, v in raw_vocab.items():
vocab[word] = Vocab(count=v, index=len(index2word))
index2word.append(word)
vocab[word] = Vocab(count=v, index=len(index_to_key))
index_to_key.append(word)

index2word.sort(key=lambda word: vocab[word].count, reverse=True)
for i, word in enumerate(index2word):
index_to_key.sort(key=lambda word: vocab[word].count, reverse=True)
for i, word in enumerate(index_to_key):
vocab[word].index = i

return vocab, index2word
return vocab, index_to_key


def get_batches(pairs, neighbors, batch_size):
Expand Down
4 changes: 2 additions & 2 deletions cogdl/models/emb/metapath2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ def forward(self, data):
walks = [[str(node) for node in walk] for walk in walks]
model = Word2Vec(
walks,
size=self.dimension,
vector_size=self.dimension,
window=self.window_size,
min_count=0,
sg=1,
workers=self.worker,
iter=self.iteration,
epochs=self.iteration,
)
id2node = dict([(vid, node) for vid, node in enumerate(G.nodes())])
embeddings = np.asarray([model.wv[str(id2node[i])] for i in range(len(id2node))])
Expand Down
4 changes: 2 additions & 2 deletions cogdl/models/emb/node2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ def forward(self, graph, return_dict=False):
walks = [[str(node) for node in walk] for walk in walks]
model = Word2Vec(
walks,
size=self.dimension,
vector_size=self.dimension,
window=self.window_size,
min_count=0,
sg=1,
workers=self.worker,
iter=self.iteration,
epochs=self.iteration,
)
id2node = dict([(vid, node) for vid, node in enumerate(G.nodes())])
embeddings = np.asarray([model.wv[str(id2node[i])] for i in range(len(id2node))])
Expand Down
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ matplotlib
tqdm
numpy>=1.21
scipy
gensim<4.0
gensim>=4.0
grave
scikit_learn==0.24.2
tabulate
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def find_version(filename):
"tqdm",
"numpy>=1.21",
"scipy",
"gensim<4.0",
"gensim>=4.0",
"grave",
"scikit_learn",
"tabulate",
Expand Down
6 changes: 3 additions & 3 deletions tests/models/emb/test_deepwalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, data: Dict[str, List[float]]) -> None:
embed_3 = [0.3, 0.2, 0.1, -0.1]


def creator(walks, size, window, min_count, sg, workers, iter):
def creator(walks, vector_size, window, min_count, sg, workers, epochs):
return Word2VecFake({"0": embed_1, "1": embed_2, "2": embed_3})


Expand Down Expand Up @@ -93,9 +93,9 @@ def test_will_pass_correct_number_of_walks():
graph = Graph(edge_index=(torch.LongTensor([0, 1]), torch.LongTensor([1, 2])))
captured_walks_no = []

def creator_mocked(walks, size, window, min_count, sg, workers, iter):
def creator_mocked(walks, vector_size, window, min_count, sg, workers, epochs):
captured_walks_no.append(len(walks))
return creator(walks, size, window, min_count, sg, workers, iter)
return creator(walks, vector_size, window, min_count, sg, workers, epochs)

model(graph, creator_mocked)
assert captured_walks_no[0] == args.walk_num * graph.num_nodes
Expand Down

0 comments on commit f5f92d6

Please sign in to comment.