[WIP] Adding sklearn wrapper for LDA code #932
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Using wrappers for the scikit-learn API" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"This tutorial shows how to use gensim models as part of a scikit-learn workflow, with the help of the wrappers found in ```gensim.sklearn_integration.base```" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The wrappers available (as of now) are:\n", | ||
"* LdaModel (```gensim.sklearn_integration.base.LdaModel```), which implements gensim's ```LdaModel``` in a scikit-learn interface" | ||
Review comment: Please update the ipynb with the new names of the .py file and of the class. |
||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### LdaModel" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"To use LdaModel, begin by importing the LdaModel wrapper" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from gensim.sklearn_integration.base import LdaModel" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Next we will create a dummy set of texts and convert it into a corpus" | ||
Review comment: Please add the examples to the ipynb from https://gist.github.com/AadityaJ/c98da3d01f76f068242c17b5e1593973 |
||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from gensim.corpora import mmcorpus, Dictionary\n", | ||
"texts = [['human', 'interface', 'computer'],\n", | ||
" ['survey', 'user', 'computer', 'system', 'response', 'time'],\n", | ||
" ['eps', 'user', 'interface', 'system'],\n", | ||
" ['system', 'human', 'system', 'eps'],\n", | ||
" ['user', 'response', 'time'],\n", | ||
" ['trees'],\n", | ||
" ['graph', 'trees'],\n", | ||
" ['graph', 'minors', 'trees'],\n", | ||
" ['graph', 'minors', 'survey']]\n", | ||
"dictionary = Dictionary(texts)\n", | ||
"corpus = [dictionary.doc2bow(text) for text in texts]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Then run the LdaModel on the corpus" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[(0, u'0.271*system + 0.181*eps + 0.181*interface + 0.181*human + 0.091*computer + 0.091*user + 0.001*trees + 0.001*graph + 0.001*time + 0.001*minors'), (1, u'0.166*graph + 0.166*trees + 0.111*user + 0.111*survey + 0.111*response + 0.111*minors + 0.111*time + 0.056*computer + 0.056*system + 0.001*human')]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"model=LdaModel(n_topics=2,id2word=dictionary,n_iter=20, random_state=1)\n", | ||
"model.fit(corpus)\n", | ||
"print model.print_topics(2)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 2", | ||
"language": "python", | ||
"name": "python2" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
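The notebook above converts tokenized texts into a bag-of-words corpus with `Dictionary.doc2bow`. As a rough illustration of what that step produces, here is a minimal stdlib-only sketch. It is not gensim's implementation: the helper names `build_token_ids` and `to_bow` are hypothetical, and the ids assigned here (order of first appearance) will generally differ from the ids gensim assigns.

```python
from collections import Counter


def build_token_ids(texts):
    """Assign an integer id to each distinct token, in order of first appearance."""
    token2id = {}
    for text in texts:
        for token in text:
            token2id.setdefault(token, len(token2id))
    return token2id


def to_bow(text, token2id):
    """Return a sorted list of (token_id, count) pairs for one tokenized document."""
    counts = Counter(token2id[token] for token in text if token in token2id)
    return sorted(counts.items())


# Two of the dummy documents used in the notebook.
texts = [['human', 'interface', 'computer'],
         ['system', 'human', 'system', 'eps']]
token2id = build_token_ids(texts)
corpus = [to_bow(text, token2id) for text in texts]
print(corpus)
```

Each document becomes a list of `(token_id, count)` pairs, which is the shape of input that the wrapper's `fit` receives as `X`.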
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
"""scikit learn wrapper for gensim | ||
Review comment: Missing file preamble (encoding, author, license etc). |
||
Contains various gensim-based implementations | ||
that follow scikit-learn conventions. | ||
See [1] for complete set of conventions. | ||
[1] http://scikit-learn.org/stable/developers/ | ||
""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (C) 2011 Radim Rehurek <[email protected]> | ||
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html | ||
# | ||
""" | ||
scikit learn interface for gensim for easy use of gensim with scikit-learn | ||
""" | ||
import numpy as np | ||
import gensim.models.ldamodel | ||
|
||
|
||
class BaseClass(object): | ||
def __init__(self): | ||
"""init | ||
base class to be always inherited | ||
to be used in the future | ||
""" | ||
def run(self): # to test | ||
return np.array([0, 0, 0]) | ||
|
||
|
||
class LdaModel(object): | ||
""" | ||
Base LDA module | ||
""" | ||
def __init__(self, n_topics=5, n_iter=2000, alpha=0.1, eta=0.01, random_state=None, | ||
refresh=10,lda_model=None,id2word=None,passes=20,ex=None): | ||
""" | ||
Base LDA code. Wrapper parameters map to gensim.models.LdaModel arguments as follows: | ||
n_topics : num_topics | ||
n_iter : iterations | ||
passes : passes | ||
refresh : update_every | ||
alpha : alpha | ||
eta : eta | ||
random_state : random_state | ||
id2word : id2word | ||
The corpus is not passed here; it is supplied to fit(). | ||
""" | ||
self.n_topics = n_topics | ||
self.n_iter = n_iter | ||
self.alpha = alpha | ||
self.eta = eta | ||
self.random_state = random_state | ||
self.refresh = refresh | ||
self.id2word=id2word | ||
self.passes=passes | ||
# use lda_model variable as object | ||
self.lda_model = lda_model | ||
# perform appropriate checks | ||
if alpha <= 0: | ||
raise ValueError("alpha value must be larger than zero") | ||
if eta <= 0: | ||
raise ValueError("eta value must be larger than zero") | ||
|
||
def get_params(self, deep=True): | ||
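# deep is accepted for scikit-learn compatibility; the same dict is returned either way | ||
return {"n_topics": self.n_topics, "n_iter": self.n_iter, "alpha": self.alpha, "eta": self.eta, "random_state": self.random_state, "refresh": self.refresh, "lda_model": self.lda_model, "id2word": self.id2word, "passes": self.passes} | ||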
|
||
def set_params(self, **parameters): | ||
for parameter, value in parameters.items(): | ||
setattr(self, parameter, value) | ||
return self | ||
|
||
def fit(self,X,y=None): | ||
Review comment: PEP8: Spaces after commas. |
Reply: Suggestions have been incorporated, thanks. |
||
""" | ||
Fit the wrapped gensim.models.LdaModel on the corpus X. | ||
TODO: the corpus still requires gensim preprocessing (bag-of-words format). | ||
Calls: | ||
>>> gensim.models.LdaModel(corpus=X, num_topics=n_topics, id2word=id2word, passes=passes, update_every=refresh, alpha=alpha, iterations=n_iter, eta=eta, random_state=random_state) | ||
""" | ||
if X is None: | ||
raise AttributeError("Corpus defined as none") | ||
self.lda_model = gensim.models.LdaModel(corpus=X,num_topics=self.n_topics, id2word=self.id2word, passes=self.passes, | ||
Review comment: Code style: we don't use vertical indent in gensim. Use hanging indent. (plus, spaces & commas again) |
||
update_every=self.refresh,alpha=self.alpha, iterations=self.n_iter, | ||
eta=self.eta,random_state=self.random_state) | ||
return self  # scikit-learn convention: fit returns the estimator itself | ||
|
||
def print_topics(self,n_topics=20,num_words=20,log=True): | ||
""" | ||
print all the topics | ||
using the object lda_model | ||
""" | ||
return self.lda_model.show_topics(num_topics=n_topics,num_words=num_words,log=log) | ||
|
||
def transform(self, bow, minimum_probability=None, minimum_phi_value=None, per_word_topics=False): | ||
""" | ||
takes as an input a new document (bow) and | ||
Return topic distribution for the given document bow, as a list of (topic_id, topic_probability) 2-tuples. | ||
""" | ||
return self.lda_model.get_document_topics(bow,minimum_probability=minimum_probability,minimum_phi_value=minimum_phi_value, | ||
per_word_topics=per_word_topics) | ||
# might need to do more | ||
def get_term_topics(self,wordid,minimum_probability=None): | ||
""" | ||
Return the most relevant topics for a particular word. | ||
Pass either the word id or the word itself. | ||
""" | ||
return self.lda_model.get_term_topics(wordid,minimum_probability=minimum_probability) | ||
|
||
def get_topic_terms(self,topicid,topn=10): | ||
""" | ||
Return a list of (wordid, probability) tuples for the given topic. | ||
topn limits the number of terms returned. | ||
""" | ||
return self.lda_model.get_topic_terms(topicid=topicid,topn=topn) |
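The wrapper above aims at the scikit-learn estimator interface, in which `get_params` returns the constructor parameters as a dict, `set_params` updates them with the built-in `setattr` and returns the estimator, and `fit` returns the estimator so that calls can be chained. The following is a minimal sketch of that contract only. `StubTopicModel` is a hypothetical stand-in that trains nothing and is not part of gensim or of this PR.

```python
class StubTopicModel(object):
    """Hypothetical stand-in for a wrapped model; it does no real training."""

    def __init__(self, n_topics=5, passes=20):
        self.n_topics = n_topics
        self.passes = passes
        self.fitted_ = False

    def get_params(self, deep=True):
        # Return every constructor parameter, whatever the value of deep.
        return {"n_topics": self.n_topics, "passes": self.passes}

    def set_params(self, **parameters):
        # setattr is a built-in function, not a method on the instance.
        for name, value in parameters.items():
            setattr(self, name, value)
        return self

    def fit(self, X, y=None):
        if X is None:
            raise ValueError("X must not be None")
        self.fitted_ = True  # a real wrapper would train the model here
        return self


# Returning self from set_params and fit allows chained calls.
model = StubTopicModel().set_params(n_topics=2).fit([[(0, 1)]])
print(model.get_params())
```

Returning the estimator rather than the underlying model keeps the object usable inside scikit-learn utilities that expect `fit` to hand back the estimator.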
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import six | ||
import unittest | ||
|
||
from gensim.sklearn_integration import base | ||
from gensim.corpora import Dictionary | ||
texts = [['human', 'interface', 'computer'], | ||
['survey', 'user', 'computer', 'system', 'response', 'time'], | ||
['eps', 'user', 'interface', 'system'], | ||
['system', 'human', 'system', 'eps'], | ||
['user', 'response', 'time'], | ||
['trees'], | ||
['graph', 'trees'], | ||
['graph', 'minors', 'trees'], | ||
['graph', 'minors', 'survey']] | ||
dictionary = Dictionary(texts) | ||
corpus = [dictionary.doc2bow(text) for text in texts] | ||
|
||
|
||
class TestLdaModel(unittest.TestCase): | ||
Review comment: All Python classes should inherit from object (new-style classes). |
||
def setUp(self): | ||
self.model = base.LdaModel(id2word=dictionary, n_topics=2, passes=100) | ||
self.model.fit(corpus) | ||
|
||
def testPrintTopic(self): | ||
Review comment: Please add a partial_fit test |
||
topic = self.model.print_topics(2) | ||
|
||
# print_topics returns (topic_id, topic_string) pairs | ||
for k, v in topic: | ||
self.assertTrue(isinstance(k, six.integer_types)) | ||
self.assertTrue(isinstance(v, six.string_types)) | ||
|
||
if __name__ == '__main__': | ||
unittest.main() | ||
Review comment: Missing newline at the end of file. |
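For `unittest.main()` to discover and run a test, the test class has to derive from `unittest.TestCase`, and shared fixtures are conventionally built in `setUp` rather than `__init__`. A minimal sketch of that layout follows. `StubModel` is a hypothetical stand-in that returns `(topic_id, topic_string)` pairs so the example runs without gensim; the suite is run programmatically here instead of through `unittest.main()`.

```python
import unittest


class StubModel(object):
    """Hypothetical stand-in returning (topic_id, topic_string) pairs."""

    def print_topics(self, n_topics=2):
        return [(i, "0.5*word_a + 0.5*word_b") for i in range(n_topics)]


class TestStubModel(unittest.TestCase):
    def setUp(self):
        # Runs before every test method.
        self.model = StubModel()

    def test_print_topics(self):
        topics = self.model.print_topics(2)
        self.assertEqual(len(topics), 2)
        for topic_id, description in topics:
            self.assertIsInstance(topic_id, int)
            self.assertIsInstance(description, str)


# Load and run the test case explicitly so the result can be inspected.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestStubModel)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```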
Review comment: Please resolve merge conflicts. Only one line should be added to changelog. Remove extra 2 lines about other changes.
Review comment: Please merge in develop branch to remove merge conflicts