[WIP] Sklearn wrapper for RandomProjections Model #1395
Merged: menshikh-iv merged 27 commits into piskvorky:develop from chinmayapancholi13:rp_wrapper_scikitlearn on Jun 20, 2017
Commits
0c5bcb0 created new file for rpmodel_sklearn_wrapper (chinmayapancholi13)
0810428 updated get_params, set_params functions (chinmayapancholi13)
d67f047 correction in calling init function (chinmayapancholi13)
a9ce401 added fit, transform, partial_fit function (chinmayapancholi13)
05ad743 added tests for Rp model's sklearn wrapper (chinmayapancholi13)
f1b9c4a minor correction in docstring in LDA and LSI models (chinmayapancholi13)
8696e54 added newline before class definition (PEP8) (chinmayapancholi13)
fe2f947 removed 'corpus' from 'init' and set 'corpus' in 'fit' (chinmayapancholi13)
7317173 updated docstring for 'fit' function (chinmayapancholi13)
692be88 refactored code to use 'self.model' (chinmayapancholi13)
a2ec746 code style changes (chinmayapancholi13)
954715e refactored wrapper and tests (chinmayapancholi13)
6c3b819 removed 'self.corpus' attribute and refactored slightly (chinmayapancholi13)
aee04ff updated 'self.__model' to 'self.gensim_model' (chinmayapancholi13)
a73dacc updated test data (chinmayapancholi13)
da602d9 updated 'fit' and 'transform' methods (chinmayapancholi13)
c1087ac updated 'testTransform' test (chinmayapancholi13)
00f5336 PEP8 change (chinmayapancholi13)
376959d updated 'testTransform' test (chinmayapancholi13)
9c888d6 added 'NotFittedError' in 'transform' function (chinmayapancholi13)
373c36c added 'testPersistence' and 'testModelNotFitted' tests (chinmayapancholi13)
f3c3601 added input 'docs' description in 'transform' function (chinmayapancholi13)
ab90b68 added 'testPipeline' test (chinmayapancholi13)
928c7f2 replaced 'text_lda' variable with 'text_rp' (chinmayapancholi13)
cf13c9a updated 'testPersistence' test (chinmayapancholi13)
cde12f2 set fixed seed in 'testPipeline' test (chinmayapancholi13)
26cd2df Merge branch 'develop' into rp_wrapper_scikitlearn (menshikh-iv)
gensim/sklearn_integration/sklearn_wrapper_gensim_rpmodel.py (79 changes: 79 additions & 0 deletions)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2011 Radim Rehurek <[email protected]> | ||
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html | ||
|
||
""" | ||
Scikit learn interface for gensim for easy use of gensim with scikit-learn | ||
Follows scikit-learn API conventions | ||
""" | ||
|
||
import numpy as np | ||
from sklearn.base import TransformerMixin, BaseEstimator | ||
from sklearn.exceptions import NotFittedError | ||
|
||
from gensim import models | ||
from gensim.sklearn_integration import base_sklearn_wrapper | ||
|
||
|
||
class SklRpModel(base_sklearn_wrapper.BaseSklearnWrapper, TransformerMixin, BaseEstimator): | ||
""" | ||
Base RP module | ||
""" | ||
|
||
def __init__(self, id2word=None, num_topics=300): | ||
""" | ||
Sklearn wrapper for the RP model. Wraps gensim.models.RpModel, which is stored in self.gensim_model after fit. | ||
""" | ||
self.gensim_model = None | ||
self.id2word = id2word | ||
self.num_topics = num_topics | ||
|
||
def get_params(self, deep=True): | ||
""" | ||
Returns all parameters as dictionary. | ||
""" | ||
return {"id2word": self.id2word, "num_topics": self.num_topics} | ||
|
||
def set_params(self, **parameters): | ||
""" | ||
Set all parameters. | ||
""" | ||
super(SklRpModel, self).set_params(**parameters) | ||
|
||
def fit(self, X, y=None): | ||
""" | ||
Fit the model according to the given training data. | ||
Calls gensim.models.RpModel | ||
""" | ||
self.gensim_model = models.RpModel(corpus=X, id2word=self.id2word, num_topics=self.num_topics) | ||
return self | ||
|
||
def transform(self, docs): | ||
""" | ||
Take documents/corpus as input. | ||
Return RP representation of the input documents/corpus. | ||
The input `docs` can correspond to multiple documents like : [ [(0, 1.0), (1, 1.0), (2, 1.0)], [(0, 1.0), (3, 1.0), (4, 1.0), (5, 1.0), (6, 1.0), (7, 1.0)] ] | ||
or a single document like : [(0, 1.0), (1, 1.0), (2, 1.0)] | ||
""" | ||
if self.gensim_model is None: | ||
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.") | ||
|
||
# Treat the input as a list of documents, wrapping a single document in a list | ||
check = lambda x: [x] if isinstance(x[0], tuple) else x | ||
docs = check(docs) | ||
X = [[] for _ in range(0, len(docs))] | ||
|
||
for k, v in enumerate(docs): | ||
transformed_doc = self.gensim_model[v] | ||
probs_docs = list(map(lambda x: x[1], transformed_doc)) | ||
# Everything should be equal in length | ||
if len(probs_docs) != self.num_topics: | ||
probs_docs.extend([1e-12] * (self.num_topics - len(probs_docs))) | ||
X[k] = probs_docs | ||
|
||
return np.reshape(np.array(X), (len(docs), self.num_topics)) | ||
|
||
def partial_fit(self, X): | ||
raise NotImplementedError("'partial_fit' has not been implemented for SklRpModel") |
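The `transform` method above does two things: it wraps a single bag-of-words document in a list, and it turns each sparse result into a row of exactly `num_topics` values. The following is a minimal, standard-library-only sketch of those two steps, not the PR's implementation. The helper names `as_corpus` and `to_dense_rows` are hypothetical. Unlike the wrapper, which appends `1e-12` padding at the end of a short row, this sketch places each value at its own index and fills gaps with `0.0`; that is a deliberate variation for illustration.

```python
def as_corpus(docs):
    """Wrap a single bag-of-words document in a list; leave a corpus as it is."""
    # A single document is a list of (index, value) tuples;
    # a corpus is a list of such documents.
    if docs and isinstance(docs[0], tuple):
        return [docs]
    return list(docs)


def to_dense_rows(docs, num_topics, fill=0.0):
    """Return one row of length num_topics per document (hypothetical helper)."""
    rows = []
    for doc in as_corpus(docs):
        row = [fill] * num_topics
        for index, value in doc:
            # Put each value in the column named by its index,
            # rather than appending values in arrival order.
            row[index] = value
        rows.append(row)
    return rows


single = [(0, 0.5), (1, -0.25)]
batch = [[(0, 0.5), (1, -0.25)], [(1, 1.0)]]
print(to_dense_rows(single, num_topics=2))  # [[0.5, -0.25]]
print(to_dense_rows(batch, num_topics=2))   # [[0.5, -0.25], [0.0, 1.0]]
```

Note that an empty list is ambiguous here (empty document or empty corpus); the sketch treats it as an empty corpus and returns no rows, whereas the `x[0]` lookup in the wrapper would raise `IndexError` on such input.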
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,10 +15,11 @@ | |
except ImportError: | ||
raise unittest.SkipTest("Test requires scikit-learn to be installed, which is not available") | ||
|
||
from gensim.sklearn_integration.sklearn_wrapper_gensim_rpmodel import SklRpModel | ||
from gensim.sklearn_integration.sklearn_wrapper_gensim_ldamodel import SklLdaModel | ||
from gensim.sklearn_integration.sklearn_wrapper_gensim_lsimodel import SklLsiModel | ||
from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaseqmodel import SklLdaSeqModel | ||
-from gensim.corpora import Dictionary | ||
+from gensim.corpora import mmcorpus, Dictionary | ||
from gensim import matutils | ||
|
||
module_path = os.path.dirname(__file__) # needed because sample data files are located in the same folder | ||
|
@@ -328,5 +329,72 @@ def testModelNotFitted(self): | |
self.assertRaises(NotFittedError, ldaseq_wrapper.transform, doc) | ||
|
||
|
||
class TestSklRpModelWrapper(unittest.TestCase): | ||
def setUp(self): | ||
numpy.random.seed(13) | ||
self.model = SklRpModel(num_topics=2) | ||
self.corpus = mmcorpus.MmCorpus(datapath('testcorpus.mm')) | ||
self.model.fit(self.corpus) | ||
|
||
def testTransform(self): | ||
# transform two documents | ||
docs = [] | ||
docs.append(list(self.corpus)[0]) | ||
docs.append(list(self.corpus)[1]) | ||
matrix = self.model.transform(docs) | ||
self.assertEqual(matrix.shape[0], 2) | ||
self.assertEqual(matrix.shape[1], self.model.num_topics) | ||
|
||
# transform one document | ||
doc = list(self.corpus)[0] | ||
matrix = self.model.transform(doc) | ||
self.assertEqual(matrix.shape[0], 1) | ||
self.assertEqual(matrix.shape[1], self.model.num_topics) | ||
|
||
def testSetGetParams(self): | ||
# updating only one param | ||
self.model.set_params(num_topics=3) | ||
model_params = self.model.get_params() | ||
self.assertEqual(model_params["num_topics"], 3) | ||
|
||
def testPipeline(self): | ||
numpy.random.seed(0) # set fixed seed to get similar values every time | ||
model = SklRpModel(num_topics=2) | ||
with open(datapath('mini_newsgroup'), 'rb') as f: | ||
compressed_content = f.read() | ||
uncompressed_content = codecs.decode(compressed_content, 'zlib_codec') | ||
cache = pickle.loads(uncompressed_content) | ||
data = cache | ||
id2word = Dictionary(map(lambda x: x.split(), data.data)) | ||
corpus = [id2word.doc2bow(i.split()) for i in data.data] | ||
numpy.random.mtrand.RandomState(1) # set seed for getting same result | ||
clf = linear_model.LogisticRegression(penalty='l2', C=0.1) | ||
text_rp = Pipeline((('features', model,), ('classifier', clf))) | ||
text_rp.fit(corpus, data.target) | ||
score = text_rp.score(corpus, data.target) | ||
self.assertGreater(score, 0.40) | ||
|
||
def testPersistence(self): | ||
Review comment on testPersistence: "Same as LdaSeq". Author reply: "Thanks. Done." |
||
model_dump = pickle.dumps(self.model) | ||
model_load = pickle.loads(model_dump) | ||
|
||
doc = list(self.corpus)[0] | ||
loaded_transformed_vecs = model_load.transform(doc) | ||
|
||
# sanity check for transformation operation | ||
self.assertEqual(loaded_transformed_vecs.shape[0], 1) | ||
self.assertEqual(loaded_transformed_vecs.shape[1], model_load.num_topics) | ||
|
||
# comparing the original and loaded models | ||
original_transformed_vecs = self.model.transform(doc) | ||
passed = numpy.allclose(sorted(loaded_transformed_vecs), sorted(original_transformed_vecs), atol=1e-1) | ||
self.assertTrue(passed) | ||
|
||
def testModelNotFitted(self): | ||
rpmodel_wrapper = SklRpModel(num_topics=2) | ||
doc = list(self.corpus)[0] | ||
self.assertRaises(NotFittedError, rpmodel_wrapper.transform, doc) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
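The `testModelNotFitted` and `testTransform` cases above check a contract that can be tried out without gensim or scikit-learn installed. The sketch below is an assumption-laden stand-in, not the PR's code: `ToyWrapper` is a hypothetical class that only mimics the fit/transform contract of `SklRpModel`, and the local `NotFittedError` stands in for `sklearn.exceptions.NotFittedError`.

```python
import unittest


class NotFittedError(ValueError):
    """Local stand-in for sklearn.exceptions.NotFittedError (assumption)."""


class ToyWrapper:
    """Hypothetical wrapper that mimics the fit/transform contract only."""

    def __init__(self, num_topics=2):
        self.num_topics = num_topics
        self.model = None  # set by fit()

    def fit(self, X, y=None):
        # Stand-in for training: remember only how many documents were seen.
        self.model = {"num_docs": len(X)}
        return self

    def transform(self, docs):
        if self.model is None:
            raise NotFittedError("Call 'fit' before 'transform'.")
        # Return one row of num_topics placeholder values per document.
        return [[0.0] * self.num_topics for _ in docs]


class TestToyWrapper(unittest.TestCase):
    def test_transform_before_fit_raises(self):
        with self.assertRaises(NotFittedError):
            ToyWrapper().transform([[(0, 1.0)]])

    def test_transform_shape_after_fit(self):
        wrapper = ToyWrapper(num_topics=3).fit([[(0, 1.0)]])
        rows = wrapper.transform([[(0, 1.0)], [(1, 2.0)]])
        self.assertEqual((len(rows), len(rows[0])), (2, 3))


# Run the two cases explicitly instead of calling unittest.main(),
# so the snippet does not exit the interpreter.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestToyWrapper)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Returning `self` from `fit` is what allows the chained `fit(...).transform(...)` call in the second test, and it is the same convention the wrapper follows so that it can sit inside a `Pipeline`.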