Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Sklearn wrapper for RandomProjections Model #1395

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
0c5bcb0
created new file for rpmodel_sklearn_wrapper
chinmayapancholi13 Jun 6, 2017
0810428
updated get_params, set_params functions
chinmayapancholi13 Jun 6, 2017
d67f047
correction in calling init function
chinmayapancholi13 Jun 7, 2017
a9ce401
added fit, transform, partial_fit function
chinmayapancholi13 Jun 7, 2017
05ad743
added tests for Rp model's sklearn wrapper
chinmayapancholi13 Jun 7, 2017
f1b9c4a
minor correction in docstring in LDA and LSI models
chinmayapancholi13 Jun 7, 2017
8696e54
added newline before class definition (PEP8)
chinmayapancholi13 Jun 8, 2017
fe2f947
removed 'corpus' from 'init' and set 'corpus' in 'fit'
chinmayapancholi13 Jun 8, 2017
7317173
updated docstring for 'fit' function
chinmayapancholi13 Jun 8, 2017
692be88
refactored code to use 'self.model'
chinmayapancholi13 Jun 13, 2017
a2ec746
code style changes
chinmayapancholi13 Jun 13, 2017
954715e
refactored wrapper and tests
chinmayapancholi13 Jun 14, 2017
6c3b819
removed 'self.corpus' attribute and refactored slightly
chinmayapancholi13 Jun 14, 2017
aee04ff
updated 'self.__model' to 'self.gensim_model'
chinmayapancholi13 Jun 15, 2017
a73dacc
updated test data
chinmayapancholi13 Jun 15, 2017
da602d9
updated 'fit' and 'transform' methods
chinmayapancholi13 Jun 15, 2017
c1087ac
updated 'testTransform' test
chinmayapancholi13 Jun 15, 2017
00f5336
PEP8 change
chinmayapancholi13 Jun 15, 2017
376959d
updated 'testTransform' test
chinmayapancholi13 Jun 15, 2017
9c888d6
added 'NotFittedError' in 'transform' function
chinmayapancholi13 Jun 16, 2017
373c36c
added 'testPersistence' and 'testModelNotFitted' tests
chinmayapancholi13 Jun 16, 2017
f3c3601
added input 'docs' description in 'transform' function
chinmayapancholi13 Jun 16, 2017
ab90b68
added 'testPipeline' test
chinmayapancholi13 Jun 16, 2017
928c7f2
replaced 'text_lda' variable with 'text_rp'
chinmayapancholi13 Jun 18, 2017
cf13c9a
updated 'testPersistence' test
chinmayapancholi13 Jun 19, 2017
cde12f2
set fixed seed in 'testPipeline' test
chinmayapancholi13 Jun 19, 2017
26cd2df
Merge branch 'develop' into rp_wrapper_scikitlearn
menshikh-iv Jun 20, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions gensim/sklearn_integration/sklearn_wrapper_gensim_rpmodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <[email protected]>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Scikit learn interface for gensim for easy use of gensim with scikit-learn
Follows scikit-learn API conventions
"""

import numpy as np
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.exceptions import NotFittedError

from gensim import models
from gensim.sklearn_integration import base_sklearn_wrapper


class SklRpModel(base_sklearn_wrapper.BaseSklearnWrapper, TransformerMixin, BaseEstimator):
"""
Base RP module
"""

def __init__(self, id2word=None, num_topics=300):
"""
Sklearn wrapper for RP model. Class derived from gensim.models.RpModel.
"""
self.gensim_model = None
self.id2word = id2word
self.num_topics = num_topics

def get_params(self, deep=True):
"""
Returns all parameters as dictionary.
"""
return {"id2word": self.id2word, "num_topics": self.num_topics}

def set_params(self, **parameters):
"""
Set all parameters.
"""
super(SklRpModel, self).set_params(**parameters)

def fit(self, X, y=None):
"""
Fit the model according to the given training data.
Calls gensim.models.RpModel
"""
self.gensim_model = models.RpModel(corpus=X, id2word=self.id2word, num_topics=self.num_topics)
return self

def transform(self, docs):
"""
Take documents/corpus as input.
Return RP representation of the input documents/corpus.
The input `docs` can correspond to multiple documents like : [ [(0, 1.0), (1, 1.0), (2, 1.0)], [(0, 1.0), (3, 1.0), (4, 1.0), (5, 1.0), (6, 1.0), (7, 1.0)] ]
or a single document like : [(0, 1.0), (1, 1.0), (2, 1.0)]
"""
if self.gensim_model is None:
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")

# The input as array of array
check = lambda x: [x] if isinstance(x[0], tuple) else x
docs = check(docs)
X = [[] for _ in range(0, len(docs))]

for k, v in enumerate(docs):
transformed_doc = self.gensim_model[v]
probs_docs = list(map(lambda x: x[1], transformed_doc))
# Everything should be equal in length
if len(probs_docs) != self.num_topics:
probs_docs.extend([1e-12] * (self.num_topics - len(probs_docs)))
X[k] = probs_docs

return np.reshape(np.array(X), (len(docs), self.num_topics))

def partial_fit(self, X):
raise NotImplementedError("'partial_fit' has not been implemented for SklRpModel")
70 changes: 69 additions & 1 deletion gensim/test/test_sklearn_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
except ImportError:
raise unittest.SkipTest("Test requires scikit-learn to be installed, which is not available")

from gensim.sklearn_integration.sklearn_wrapper_gensim_rpmodel import SklRpModel
from gensim.sklearn_integration.sklearn_wrapper_gensim_ldamodel import SklLdaModel
from gensim.sklearn_integration.sklearn_wrapper_gensim_lsimodel import SklLsiModel
from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaseqmodel import SklLdaSeqModel
from gensim.corpora import Dictionary
from gensim.corpora import mmcorpus, Dictionary
from gensim import matutils

module_path = os.path.dirname(__file__) # needed because sample data files are located in the same folder
Expand Down Expand Up @@ -328,5 +329,72 @@ def testModelNotFitted(self):
self.assertRaises(NotFittedError, ldaseq_wrapper.transform, doc)


class TestSklRpModelWrapper(unittest.TestCase):
def setUp(self):
numpy.random.seed(13)
self.model = SklRpModel(num_topics=2)
self.corpus = mmcorpus.MmCorpus(datapath('testcorpus.mm'))
self.model.fit(self.corpus)

def testTransform(self):
# tranform two documents
docs = []
docs.append(list(self.corpus)[0])
docs.append(list(self.corpus)[1])
matrix = self.model.transform(docs)
self.assertEqual(matrix.shape[0], 2)
self.assertEqual(matrix.shape[1], self.model.num_topics)

# tranform one document
doc = list(self.corpus)[0]
matrix = self.model.transform(doc)
self.assertEqual(matrix.shape[0], 1)
self.assertEqual(matrix.shape[1], self.model.num_topics)

def testSetGetParams(self):
# updating only one param
self.model.set_params(num_topics=3)
model_params = self.model.get_params()
self.assertEqual(model_params["num_topics"], 3)

def testPipeline(self):
numpy.random.seed(0) # set fixed seed to get similar values everytime
model = SklRpModel(num_topics=2)
with open(datapath('mini_newsgroup'), 'rb') as f:
compressed_content = f.read()
uncompressed_content = codecs.decode(compressed_content, 'zlib_codec')
cache = pickle.loads(uncompressed_content)
data = cache
id2word = Dictionary(map(lambda x: x.split(), data.data))
corpus = [id2word.doc2bow(i.split()) for i in data.data]
numpy.random.mtrand.RandomState(1) # set seed for getting same result
clf = linear_model.LogisticRegression(penalty='l2', C=0.1)
text_rp = Pipeline((('features', model,), ('classifier', clf)))
text_rp.fit(corpus, data.target)
score = text_rp.score(corpus, data.target)
self.assertGreater(score, 0.40)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as LdaSeq

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Done.


def testPersistence(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as LdaSeq

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Done.

model_dump = pickle.dumps(self.model)
model_load = pickle.loads(model_dump)

doc = list(self.corpus)[0]
loaded_transformed_vecs = model_load.transform(doc)

# sanity check for transformation operation
self.assertEqual(loaded_transformed_vecs.shape[0], 1)
self.assertEqual(loaded_transformed_vecs.shape[1], model_load.num_topics)

# comparing the original and loaded models
original_transformed_vecs = self.model.transform(doc)
passed = numpy.allclose(sorted(loaded_transformed_vecs), sorted(original_transformed_vecs), atol=1e-1)
self.assertTrue(passed)

def testModelNotFitted(self):
rpmodel_wrapper = SklRpModel(num_topics=2)
doc = list(self.corpus)[0]
self.assertRaises(NotFittedError, rpmodel_wrapper.transform, doc)


if __name__ == '__main__':
unittest.main()