Skip to content

Commit

Permalink
Allow indexing with np.int64 in doc2vec. Fix #1231 (#1254)
Browse files (browse the repository at this point in the history)
* doc2vec: allow indexing with np.int64 -- fixes #1231

* doc2vec: use assertEqual instead of assertEquals

* Do integer checks using both `six.integer_types` and `np.integer`

* Add more tests for np.int64 indexing
  • Loading branch information
bogdanteleaga authored and tmylk committed May 2, 2017
1 parent b6c3c2d commit 50a18b9
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 21 deletions.
4 changes: 2 additions & 2 deletions gensim/corpora/indexedcorpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""

import logging
import shelve
import six

import numpy

Expand Down Expand Up @@ -124,7 +124,7 @@ def __getitem__(self, docno):

if isinstance(docno, (slice, list, numpy.ndarray)):
return utils.SlicedCorpus(self, docno)
elif isinstance(docno, (int, numpy.integer)):
elif isinstance(docno, six.integer_types + (numpy.integer,)):
return self.docbyoffset(self.index[docno])
else:
raise ValueError('Unrecognised value for docno, use either a single integer, a slice or a numpy.ndarray')
Expand Down
7 changes: 1 addition & 6 deletions gensim/models/atmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,8 @@
# are included in the code where this is the case, for example in the log_perplexity
# and do_estep methods.

import pdb
from pdb import set_trace as st
from pprint import pprint

import logging
import numpy as np # for arrays, array broadcasting etc.
import numbers
from copy import deepcopy
from shutil import copyfile
from os.path import isfile
Expand Down Expand Up @@ -391,7 +386,7 @@ def inference(self, chunk, author2doc, doc2author, rhot, collect_sstats=False, c
doc_no = d
# Get the IDs and counts of all the words in the current document.
# TODO: this is duplication of code in LdaModel. Refactor.
if doc and not isinstance(doc[0][0], six.integer_types):
if doc and not isinstance(doc[0][0], six.integer_types + (np.integer,)):
# make sure the term IDs are ints, otherwise np will get upset
ids = [int(id) for id, _ in doc]
else:
Expand Down
20 changes: 10 additions & 10 deletions gensim/models/doc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,17 @@
from collections import namedtuple, defaultdict
from timeit import default_timer

from numpy import zeros, random, sum as np_sum, add as np_add, concatenate, \
from numpy import zeros, sum as np_sum, add as np_add, concatenate, \
repeat as np_repeat, array, float32 as REAL, empty, ones, memmap as np_memmap, \
sqrt, newaxis, ndarray, dot, vstack, dtype, divide as np_divide
sqrt, newaxis, ndarray, dot, vstack, dtype, divide as np_divide, integer


from gensim.utils import call_on_class_only
from gensim import utils, matutils # utility fnc for pickling, common scipy operations etc
from gensim.models.word2vec import Word2Vec, train_cbow_pair, train_sg_pair, train_batch_sg
from gensim.models.keyedvectors import KeyedVectors
from six.moves import xrange, zip
from six import string_types, integer_types, itervalues
from six import string_types, integer_types

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -297,7 +297,7 @@ def __init__(self, mapfile_path=None):

def note_doctag(self, key, document_no, document_length):
"""Note a document tag during initial corpus scan, for structure sizing."""
if isinstance(key, int):
if isinstance(key, integer_types + (integer,)):
self.max_rawint = max(self.max_rawint, key)
else:
if key in self.doctags:
Expand All @@ -319,7 +319,7 @@ def trained_item(self, indexed_tuple):

def _int_index(self, index):
"""Return int index for either string or int index"""
if isinstance(index, int):
if isinstance(index, integer_types + (integer,)):
return index
else:
return self.max_rawint + 1 + self.doctags[index].offset
Expand Down Expand Up @@ -347,7 +347,7 @@ def __getitem__(self, index):
If a list, return designated tags' vector representations as a
2D numpy array: #tags x #vector_size.
"""
if isinstance(index, string_types + (int,)):
if isinstance(index, string_types + integer_types + (integer,)):
return self.doctag_syn0[self._int_index(index)]

return vstack([self[i] for i in index])
Expand All @@ -356,7 +356,7 @@ def __len__(self):
return self.count

def __contains__(self, index):
if isinstance(index, int):
if isinstance(index, integer_types + (integer,)):
return index < self.count
else:
return index in self.doctags
Expand Down Expand Up @@ -439,17 +439,17 @@ def most_similar(self, positive=[], negative=[], topn=10, clip_start=0, clip_end
self.init_sims()
clip_end = clip_end or len(self.doctag_syn0norm)

if isinstance(positive, string_types + integer_types) and not negative:
if isinstance(positive, string_types + integer_types + (integer,)) and not negative:
# allow calls like most_similar('dog'), as a shorthand for most_similar(['dog'])
positive = [positive]

# add weights for each doc, if not already present; default to 1.0 for positive and -1.0 for negative docs
positive = [
(doc, 1.0) if isinstance(doc, string_types + (ndarray,) + integer_types)
(doc, 1.0) if isinstance(doc, string_types + integer_types + (ndarray, integer))
else doc for doc in positive
]
negative = [
(doc, -1.0) if isinstance(doc, string_types + (ndarray,) + integer_types)
(doc, -1.0) if isinstance(doc, string_types + integer_types + (ndarray, integer))
else doc for doc in negative
]

Expand Down
2 changes: 1 addition & 1 deletion gensim/models/ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def inference(self, chunk, collect_sstats=False):
# Lee&Seung trick which speeds things up by an order of magnitude, compared
# to Blei's original LDA-C code, cool!).
for d, doc in enumerate(chunk):
if len(doc) > 0 and not isinstance(doc[0][0], six.integer_types):
if len(doc) > 0 and not isinstance(doc[0][0], six.integer_types + (np.integer,)):
# make sure the term IDs are ints, otherwise np will get upset
ids = [int(id) for id, _ in doc]
else:
Expand Down
1 change: 1 addition & 0 deletions gensim/test/test_corpora.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def test_indexing(self):

for idx, doc in enumerate(docs):
self.assertEqual(doc, corpus[idx])
self.assertEqual(doc, corpus[np.int64(idx)])

self.assertEqual(docs, list(corpus[:]))
self.assertEqual(docs[0:], list(corpus[0:]))
Expand Down
5 changes: 3 additions & 2 deletions gensim/test/test_doc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def test_int_doctags(self):
model.build_vocab(corpus)
self.assertEqual(len(model.docvecs.doctag_syn0), 300)
self.assertEqual(model.docvecs[0].shape, (100,))
self.assertEqual(model.docvecs[np.int64(0)].shape, (100,))
self.assertRaises(KeyError, model.__getitem__, '_*0')

def test_missing_string_doctag(self):
Expand Down Expand Up @@ -164,7 +165,7 @@ def test_similarity_unseen_docs(self):
def model_sanity(self, model, keep_training=True):
"""Any non-trivial model on DocsLeeCorpus can pass these sanity checks"""
fire1 = 0 # doc 0 sydney fires
fire2 = 8 # doc 8 sydney fires
fire2 = np.int64(8) # doc 8 sydney fires
tennis1 = 6 # doc 6 tennis

# inferred vector should be top10 close to bulk-trained one
Expand Down Expand Up @@ -304,7 +305,7 @@ def test_mixed_tag_types(self):
model = doc2vec.Doc2Vec()
model.build_vocab(mixed_tag_corpus)
expected_length = len(sentences) + len(model.docvecs.doctags) # 9 sentences, 7 unique first tokens
self.assertEquals(len(model.docvecs.doctag_syn0), expected_length)
self.assertEqual(len(model.docvecs.doctag_syn0), expected_length)

def models_equal(self, model, model2):
# check words/hidden-weights
Expand Down

0 comments on commit 50a18b9

Please sign in to comment.