Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix save/load_word2vec_format methods for FastText model. Fix #1743 #1755

Merged
merged 8 commits into from
Dec 6, 2017
8 changes: 8 additions & 0 deletions gensim/models/fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,14 @@ def word_vec(self, word, use_norm=False):
def load_fasttext_format(cls, *args, **kwargs):
return Ft_Wrapper.load_fasttext_format(*args, **kwargs)

@classmethod
def load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8', unicode_errors='strict',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Necessary? Personally, I would leave this out. I don't see any problem with getting a "deprecation error" instead of "not implemented error". Plus, it's weird to call a method that you know will fail. This is likely to cause more confusion than good. @menshikh-iv, what do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. Removing this function from here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@janpom agree with you

limit=None, datatype=REAL):
"""Not suppported. Use gensim.models.KeyedVectors.load_word2vec_format instead."""
return FastTextKeyedVectors.load_word2vec_format(
fname, fvocab=fvocab, binary=binary, encoding=encoding, unicode_errors=unicode_errors,
limit=limit, datatype=datatype)

def save(self, *args, **kwargs):
kwargs['ignore'] = kwargs.get('ignore', ['syn0norm', 'syn0_vocab_norm', 'syn0_ngrams_norm'])
super(FastText, self).save(*args, **kwargs)
6 changes: 6 additions & 0 deletions gensim/models/wrappers/fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,12 @@ def __contains__(self, word):
char_ngrams = compute_ngrams(word, self.min_n, self.max_n)
return any(ng in self.ngrams for ng in char_ngrams)

@classmethod
def load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8', unicode_errors='strict',
Copy link
Contributor

@jayantj jayantj Dec 5, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One tiny change I might suggest here would be to use *args and **kwargs instead. I like following this practice in situations where the subclass doesn't make use of the function params itself. It is useful because if any params in the parent class change in the future, the subclass function signature doesn't have to be modified.

limit=None, datatype=REAL):
"""Not suppported. Use gensim.models.KeyedVectors.load_word2vec_format instead."""
raise NotImplementedError("Not supported. Use gensim.models.KeyedVectors.load_word2vec_format instead.")


class FastText(Word2Vec):
"""
Expand Down
11 changes: 11 additions & 0 deletions gensim/test/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np

from gensim import utils
from gensim.models import keyedvectors
from gensim.models.word2vec import LineSentence
from gensim.models.fasttext import FastText as FT_gensim
from gensim.models.wrappers.fasttext import FastTextKeyedVectors
Expand Down Expand Up @@ -455,6 +456,16 @@ def test_cbow_neg_online(self):
)
self.online_sanity(model)

def test_persistence_word2vec_format(self):
"""Test storing/loading the model in word2vec format."""
tmpf = get_tmpfile('gensim_fasttext_w2v_format.tst')
model = FT_gensim(sentences, min_count=1)
model.wv.save_word2vec_format(tmpf, binary=True)
loaded_model_kv = keyedvectors.KeyedVectors.load_word2vec_format(tmpf, binary=True)
self.assertEqual(len(model.wv.vocab), len(loaded_model_kv.vocab))
self.assertTrue((model.wv.syn0 == loaded_model_kv.syn0).all())
self.assertRaises(NotImplementedError, FT_gensim.load_word2vec_format, tmpf)
self.assertRaises(NotImplementedError, FastTextKeyedVectors.load_word2vec_format, tmpf)

if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
Expand Down