Skip to content

Commit

Permalink
Fix incorrect vectors learned during online training for FastText. Fix #1752 (#1756)
Browse files Browse the repository at this point in the history

* adds an assert in `online_sanity` test case to check `syn0` and `syn0_vocab` are different

* creates a copy of the `syn0_vocab` vector before adding the ngram vectors to it and assigning the result to `syn0`

* adds unit test for `get_vocab_word_vecs`
  • Loading branch information
manneshiva authored and menshikh-iv committed Dec 5, 2017
1 parent ea1f3cf commit 10bd7fc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
2 changes: 1 addition & 1 deletion gensim/models/fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def __getitem__(self, word):

def get_vocab_word_vecs(self):
for w, v in self.wv.vocab.items():
word_vec = self.wv.syn0_vocab[v.index]
word_vec = np.copy(self.wv.syn0_vocab[v.index])
ngrams = self.wv.ngrams_word[w]
ngram_weights = self.wv.syn0_ngrams
for ngram in ngrams:
Expand Down
9 changes: 9 additions & 0 deletions gensim/test/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,8 @@ def online_sanity(self, model):
self.assertTrue(all(['terrorism' not in l for l in others]))
model.build_vocab(others)
model.train(others, total_examples=model.corpus_count, epochs=model.iter)
# checks that `syn0` is different from `syn0_vocab`
self.assertFalse(np.all(np.equal(model.wv.syn0, model.wv.syn0_vocab)))
self.assertFalse('terrorism' in model.wv.vocab)
self.assertFalse('orism>' in model.wv.ngrams)
model.build_vocab(terro, update=True) # update vocab
Expand Down Expand Up @@ -456,6 +458,13 @@ def test_cbow_neg_online(self):
)
self.online_sanity(model)

def test_get_vocab_word_vecs(self):
    """Check that calling `get_vocab_word_vecs` leaves `syn0_vocab` unchanged."""
    ft_model = FT_gensim(size=10, min_count=1, seed=42)
    ft_model.build_vocab(sentences)

    # Snapshot the vocab vectors so any in-place modification is detectable.
    vocab_vectors_before = np.copy(ft_model.wv.syn0_vocab)
    ft_model.get_vocab_word_vecs()

    unchanged = np.equal(ft_model.wv.syn0_vocab, vocab_vectors_before)
    self.assertTrue(np.all(unchanged))


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
Expand Down

0 comments on commit 10bd7fc

Please sign in to comment.