diff --git a/docs/notebooks/sklearn_wrapper.ipynb b/docs/notebooks/sklearn_wrapper.ipynb index 0d28429ecf..e98047dedc 100644 --- a/docs/notebooks/sklearn_wrapper.ipynb +++ b/docs/notebooks/sklearn_wrapper.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This tutorial is about using gensim models as a part of your scikit learn workflow with the help of wrappers found at ```gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel```" + "This tutorial is about using gensim models as a part of your scikit learn workflow with the help of wrappers found at ```gensim.sklearn_integration```" ] }, { @@ -19,7 +19,9 @@ "metadata": {}, "source": [ "The wrapper available (as of now) are :\n", - "* LdaModel (```gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel.SklearnWrapperLdaModel```),which implements gensim's ```LdaModel``` in a scikit-learn interface" + "* LdaModel (```gensim.sklearn_integration.sklearn_wrapper_gensim_ldamodel.SklearnWrapperLdaModel```), which implements gensim's ```LdaModel``` in a scikit-learn interface\n", + "\n", + "* LsiModel (```gensim.sklearn_integration.sklearn_wrapper_gensim_lsimodel.SklearnWrapperLsiModel```), which implements gensim's ```LsiModel``` in a scikit-learn interface" ] }, { @@ -38,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 1, "metadata": { "collapsed": false }, @@ -56,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 2, "metadata": { "collapsed": true }, @@ -85,7 +87,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 3, "metadata": { "collapsed": false }, @@ -111,7 +113,7 @@ " [ 0.84210373, 0.15789627]])" ] }, - "execution_count": 22, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -129,7 +131,7 @@ "collapsed": true }, "source": [ - "### Integration with Sklearn" + "#### Integration with Sklearn" ] }, { @@ -141,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 
23, + "execution_count": 4, "metadata": { "collapsed": false }, @@ -157,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 5, "metadata": { "collapsed": false }, @@ -179,7 +181,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 6, "metadata": { "collapsed": false }, @@ -202,7 +204,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 7, "metadata": { "collapsed": false }, @@ -211,18 +213,18 @@ "data": { "text/plain": [ "[(0,\n", - " u'0.085*\"abroad\" + 0.053*\"ciphertext\" + 0.042*\"arithmetic\" + 0.037*\"facts\" + 0.031*\"courtesy\" + 0.025*\"amolitor\" + 0.023*\"argue\" + 0.021*\"asking\" + 0.020*\"agree\" + 0.018*\"classified\"'),\n", + " u'0.025*\"456\" + 0.021*\"argue\" + 0.016*\"bitnet\" + 0.015*\"beastmaster\" + 0.014*\"cryptography\" + 0.013*\"false\" + 0.012*\"digex\" + 0.011*\"cover\" + 0.011*\"classified\" + 0.010*\"disk\"'),\n", " (1,\n", - " u'0.098*\"asking\" + 0.075*\"cryptography\" + 0.068*\"abroad\" + 0.033*\"456\" + 0.025*\"argue\" + 0.022*\"bitnet\" + 0.017*\"false\" + 0.014*\"digex\" + 0.014*\"effort\" + 0.013*\"disk\"'),\n", + " u'0.142*\"abroad\" + 0.113*\"asking\" + 0.088*\"cryptography\" + 0.044*\"ciphertext\" + 0.043*\"arithmetic\" + 0.032*\"courtesy\" + 0.030*\"facts\" + 0.021*\"argue\" + 0.019*\"amolitor\" + 0.018*\"agree\"'),\n", " (2,\n", - " u'0.023*\"accurate\" + 0.021*\"corporate\" + 0.013*\"clark\" + 0.012*\"chance\" + 0.009*\"consideration\" + 0.008*\"authentication\" + 0.008*\"dawson\" + 0.008*\"candidates\" + 0.008*\"basically\" + 0.008*\"assess\"'),\n", + " u'0.034*\"certain\" + 0.027*\"69\" + 0.025*\"book\" + 0.025*\"demand\" + 0.024*\"87\" + 0.024*\"cracking\" + 0.021*\"farm\" + 0.019*\"fierkelab\" + 0.015*\"face\" + 0.011*\"abroad\"'),\n", " (3,\n", - " u'0.016*\"cryptography\" + 0.007*\"evans\" + 0.006*\"considering\" + 0.006*\"forgot\" + 0.006*\"built\" + 0.005*\"constitutional\" + 0.005*\"fly\" + 0.004*\"cellular\" + 0.004*\"computed\" + 
0.004*\"digitized\"'),\n", + " u'0.017*\"decipher\" + 0.017*\"example\" + 0.016*\"cases\" + 0.016*\"follow\" + 0.008*\"considering\" + 0.006*\"forgot\" + 0.006*\"cellular\" + 0.005*\"evans\" + 0.005*\"computed\" + 0.005*\"cia\"'),\n", " (4,\n", - " u'0.028*\"certain\" + 0.022*\"69\" + 0.021*\"book\" + 0.020*\"demand\" + 0.020*\"cracking\" + 0.020*\"87\" + 0.017*\"farm\" + 0.017*\"fierkelab\" + 0.015*\"face\" + 0.009*\"constitutional\"')]" + " u'0.022*\"accurate\" + 0.021*\"corporate\" + 0.013*\"chance\" + 0.012*\"clark\" + 0.009*\"consideration\" + 0.009*\"candidates\" + 0.008*\"dawson\" + 0.008*\"authentication\" + 0.008*\"assess\" + 0.008*\"attempt\"')]" ] }, - "execution_count": 26, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -239,12 +241,12 @@ "collapsed": true }, "source": [ - "### Example for Using Grid Search" + "#### Example for Using Grid Search" ] }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 8, "metadata": { "collapsed": false }, @@ -256,7 +258,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 9, "metadata": { "collapsed": true }, @@ -269,7 +271,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 10, "metadata": { "collapsed": false }, @@ -280,16 +282,16 @@ "GridSearchCV(cv=5, error_score='raise',\n", " estimator=SklearnWrapperLdaModel(alpha='symmetric', chunksize=2000, corpus=None,\n", " decay=0.5, eta=None, eval_every=10, gamma_threshold=0.001,\n", - " id2word=,\n", + " id2word=,\n", " iterations=50, minimum_probability=0.01, num_topics=5,\n", " offset=1.0, passes=20, random_state=None, update_every=1),\n", " fit_params={}, iid=True, n_jobs=1,\n", " param_grid={'num_topics': (2, 3, 5, 10), 'iterations': (1, 20, 50)},\n", " pre_dispatch='2*n_jobs', refit=True, return_train_score=True,\n", - " scoring=, verbose=0)" + " scoring=, verbose=0)" ] }, - "execution_count": 32, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } 
@@ -303,7 +305,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 11, "metadata": { "collapsed": false }, @@ -311,10 +313,10 @@ { "data": { "text/plain": [ - "{'iterations': 50, 'num_topics': 3}" + "{'iterations': 20, 'num_topics': 3}" ] }, - "execution_count": 33, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -327,14 +329,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Example of Using Pipeline" + "#### Example of Using Pipeline" ] }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 12, "metadata": { - "collapsed": true + "collapsed": false }, "outputs": [], "source": [ @@ -350,7 +352,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 13, "metadata": { "collapsed": false }, @@ -362,7 +364,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 14, "metadata": { "collapsed": false }, @@ -396,6 +398,76 @@ "print_features_pipe(pipe, id2word.values())\n", "print pipe.score(corpus, data.target)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### LsiModel" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To use LsiModel begin with importing LsiModel wrapper" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from gensim.sklearn_integration.sklearn_wrapper_gensim_lsimodel import SklearnWrapperLsiModel" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Example of Using Pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 0.13652819 0.00383696 0.02635504 -0.08454895 -0.02356143 0.60020084\n", + " 1.07026252 -0.04072257 0.43732847 0.54913549 -0.20242834 -0.21855402\n", + " -1.30546283 -0.08690711 0.17606255]\n", + "Positive 
class SklearnWrapperLsiModel(models.LsiModel, TransformerMixin, BaseEstimator):
    """
    Scikit-learn style wrapper around gensim's LsiModel.

    The constructor stores the hyper-parameters (and trains right away when a
    corpus is supplied), `fit` trains on a corpus, `partial_fit` adds further
    documents to an already trained model and `transform` maps bag-of-words
    documents to a dense matrix of topic weights.
    """

    def __init__(self, corpus=None, num_topics=200, id2word=None, chunksize=20000,
                 decay=1.0, onepass=True, power_iters=2, extra_samples=100):
        """
        Sklearn wrapper for LSI model. Class derived from gensim.models.LsiModel.

        Every argument is stored unchanged as an attribute of the same name and
        is later forwarded, by keyword, to `gensim.models.LsiModel.__init__`.
        When `corpus` is given the model is trained immediately; otherwise
        training is deferred until `fit` is called.
        """
        self.corpus = corpus
        self.num_topics = num_topics
        self.id2word = id2word
        self.chunksize = chunksize
        self.decay = decay
        self.onepass = onepass
        self.extra_samples = extra_samples
        self.power_iters = power_iters

        # if 'fit' function is not used, then 'corpus' is given in init.
        # Compare against None explicitly: the truth value of an arbitrary
        # corpus object (e.g. an array-like) is ambiguous or may raise.
        if self.corpus is not None:
            self._train_lsi()

    def _train_lsi(self):
        """
        (Re)initialise the underlying LsiModel from the stored parameters.
        """
        models.LsiModel.__init__(
            self, corpus=self.corpus, num_topics=self.num_topics, id2word=self.id2word,
            chunksize=self.chunksize, decay=self.decay, onepass=self.onepass,
            power_iters=self.power_iters, extra_samples=self.extra_samples)

    def get_params(self, deep=True):
        """
        Returns all parameters as dictionary.

        `deep` is accepted for scikit-learn API compatibility and is not used.
        """
        return {"corpus": self.corpus, "num_topics": self.num_topics, "id2word": self.id2word,
                "chunksize": self.chunksize, "decay": self.decay, "onepass": self.onepass,
                "extra_samples": self.extra_samples, "power_iters": self.power_iters}

    def set_params(self, **parameters):
        """
        Set all parameters.

        Each keyword is assigned to the attribute of the same name, so the new
        values are picked up by the next call to `fit`. Returns `self`.
        """
        for parameter, value in parameters.items():
            # setattr is required here: `self.parameter = value` would write
            # every value to one attribute literally called "parameter".
            setattr(self, parameter, value)
        return self

    def fit(self, X, y=None):
        """
        Train the model on corpus `X` and return `self`.

        A scipy sparse matrix is wrapped in `matutils.Sparse2Corpus`; any other
        input is used as the corpus as-is. The corpus is stored on
        `self.corpus` and the model is then built by calling
        `gensim.models.LsiModel.__init__` with the stored parameters.
        `y` is ignored; it is only present for scikit-learn API compatibility.
        """
        if sparse.issparse(X):
            self.corpus = matutils.Sparse2Corpus(X)
        else:
            self.corpus = X

        self._train_lsi()
        return self

    def transform(self, docs):
        """
        Takes a list of documents as input ('docs'); a single document (a list
        of tuples) is accepted as well and treated as a one-document list.

        Returns a numpy array of shape (number of documents, num_topics) where
        entry (i, j) is the weight of topic j for document i. Topics missing
        from the model's sparse output for a document are filled with 1e-12.
        An empty `docs` yields an array with zero rows.
        """
        # A single document is a flat list of tuples: wrap it so that the code
        # below always sees a list of documents. The length guard keeps empty
        # input from raising IndexError on docs[0].
        if len(docs) > 0 and isinstance(docs[0], tuple):
            docs = [docs]

        # Every row has the same length; cells are pre-filled with the
        # placeholder used for topics absent from the sparse output.
        X = np.full((len(docs), self.num_topics), 1e-12)
        for row, doc in enumerate(docs):
            # self[doc] yields (topic_id, weight) pairs. Write each weight into
            # the column of its own topic id instead of packing the weights
            # left-to-right, which would shift values into the wrong column
            # whenever a topic is left out of the sparse output.
            for topic_id, weight in self[doc]:
                X[row, topic_id] = weight
        return X

    def partial_fit(self, X):
        """
        Train model over X.

        A scipy sparse matrix is wrapped in `matutils.Sparse2Corpus` before the
        documents are passed to `add_documents`. Returns `self`.
        """
        if sparse.issparse(X):
            X = matutils.Sparse2Corpus(X)
        self.add_documents(corpus=X)
        return self
class TestSklearnLSIWrapper(unittest.TestCase):
    """Tests for the scikit-learn wrapper around gensim's LsiModel."""

    def setUp(self):
        # Two-topic model trained on the module-level test corpus.
        self.model = SklearnWrapperLsiModel(id2word=dictionary, num_topics=2)
        self.model.fit(corpus)

    def testModelSanity(self):
        # print_topics is expected to yield (int topic id, string description) pairs.
        for topic_id, description in self.model.print_topics(2):
            self.assertTrue(isinstance(description, six.string_types))
            self.assertTrue(isinstance(topic_id, int))

    def testTransform(self):
        # A single document must come back as exactly one row.
        # assertEqual is used on purpose: assertTrue(a, b) treats `b` as the
        # failure message and therefore never compares the two values.
        texts_new = ['graph', 'eulerian']
        bow = self.model.id2word.doc2bow(texts_new)
        X = self.model.transform(bow)
        self.assertEqual(X.shape[0], 1)
        self.assertEqual(X.shape[1], self.model.num_topics)

        # A list of documents must come back as one row per document.
        texts_new = [['graph', 'eulerian'], ['server', 'flow'], ['path', 'system']]
        bow = [self.model.id2word.doc2bow(text) for text in texts_new]
        X = self.model.transform(bow)
        self.assertEqual(X.shape[0], 3)
        self.assertEqual(X.shape[1], self.model.num_topics)

    def testPartialFit(self):
        for _ in range(10):
            self.model.partial_fit(X=corpus)  # fit against the model again
            doc = list(corpus)[0]  # transform only the first document
            transformed = self.model[doc]
            transformed_approx = matutils.sparse2full(transformed, 2)  # better approximation
        expected = [1.39, 0.0]
        # Both sides are sorted so the check does not depend on topic order.
        passed = numpy.allclose(sorted(transformed_approx), sorted(expected), atol=1e-1)
        self.assertTrue(passed)

    def testPipeline(self):
        model = SklearnWrapperLsiModel(num_topics=2)
        with open(datapath('mini_newsgroup'), 'rb') as f:
            compressed_content = f.read()
        uncompressed_content = codecs.decode(compressed_content, 'zlib_codec')
        # NOTE: unpickling is only acceptable because this is a fixture shipped
        # with the test data; never unpickle untrusted input.
        data = pickle.loads(uncompressed_content)
        id2word = Dictionary(text.split() for text in data.data)
        bow_corpus = [id2word.doc2bow(text.split()) for text in data.data]
        clf = linear_model.LogisticRegression(penalty='l2', C=0.1)
        text_lsi = Pipeline((('features', model,), ('classifier', clf)))
        text_lsi.fit(bow_corpus, data.target)
        score = text_lsi.score(bow_corpus, data.target)
        self.assertGreater(score, 0.50)


if __name__ == '__main__':
    unittest.main()