diff --git a/gensim/corpora/textcorpus.py b/gensim/corpora/textcorpus.py index e52b60f32b..00e69a6717 100644 --- a/gensim/corpora/textcorpus.py +++ b/gensim/corpora/textcorpus.py @@ -98,28 +98,56 @@ def get_texts(self): else: yield utils.tokenize(line, lowercase=True) - def sample_texts(self, n): + def sample_texts(self, n, seed=None, length=None): """ - Yield n random texts from the corpus without replacement. + Yield n random documents from the corpus without replacement. - Given the the number of remaingin elements in stream is remaining and we need - to choose n elements, the probability for current element to be chosen is n/remaining. - If we choose it, we just decreese the n and move to the next element. + Given the number of remaining documents in a corpus, we need to choose n elements. + The probability for the current element to be chosen is n/remaining. + If we choose it, we just decrease the n and move to the next element. + Computing the corpus length may be a costly operation so you can use the optional + parameter `length` instead. + + Args: + n (int): number of documents we want to sample. + seed (int|None): if specified, use it as a seed for local random generator. + length (int|None): if specified, use it as a guess of corpus length. + It must be positive and not greater than actual corpus length. + + Yields: + list[str]: document represented as a list of tokens. See get_texts method. + + Raises: + ValueError: when n is invalid or length was set incorrectly. """ - length = len(self) - if not n <= length: - raise ValueError("sample larger than population") + random_generator = None + if seed is None: + random_generator = random + else: + random_generator = random.Random(seed) + + if length is None: + length = len(self) + if not n <= length: + raise ValueError("n is larger than length of corpus.") if not 0 <= n: - raise ValueError("negative sample size") + raise ValueError("Negative sample size.") for i, sample in enumerate(self.get_texts()): - remaining_in_stream = length - i - chance = random.randint(1, remaining_in_stream) + if i == length: + break + remaining_in_corpus = length - i + chance = random_generator.randint(1, remaining_in_corpus) if chance <= n: n -= 1 yield sample + if n != 0: + # This means that length was set to be greater than number of items in corpus + # and we were not able to sample enough documents before the stream ended. + raise ValueError("length greater than number of documents in corpus") + def __len__(self): if not hasattr(self, 'length'): # cache the corpus length diff --git a/gensim/test/test_textcorpus.py b/gensim/test/test_textcorpus.py index abf646eb97..82e3d80960 100644 --- a/gensim/test/test_textcorpus.py +++ b/gensim/test/test_textcorpus.py @@ -21,34 +21,48 @@ class TestTextCorpus(unittest.TestCase): # TODO add tests for other methods - def test_sample_text(self): - class TestTextCorpus(TextCorpus): - def __init__(self): - self.data = [["document1"], ["document2"]] + class DummyTextCorpus(TextCorpus): + def __init__(self): + self.size = 10 + self.data = [["document%s" % i] for i in range(self.size)] - def get_texts(self): - for document in self.data: - yield document + def get_texts(self): + for document in self.data: + yield document - corpus = TestTextCorpus() + def test_sample_text(self): + corpus = self.DummyTextCorpus() sample1 = list(corpus.sample_texts(1)) self.assertEqual(len(sample1), 1) - document1 = sample1[0] == ["document1"] - document2 = sample1[0] == ["document2"] - self.assertTrue(document1 or document2) + self.assertIn(sample1[0], corpus.data) - sample2 = list(corpus.sample_texts(2)) - self.assertEqual(len(sample2), 2) - self.assertEqual(sample2[0], ["document1"]) - self.assertEqual(sample2[1], ["document2"]) + sample2 = list(corpus.sample_texts(corpus.size)) + self.assertEqual(len(sample2), corpus.size) + for i in range(corpus.size): + self.assertEqual(sample2[i], ["document%s" % i]) with self.assertRaises(ValueError): - list(corpus.sample_texts(3)) + list(corpus.sample_texts(corpus.size + 1)) with self.assertRaises(ValueError): list(corpus.sample_texts(-1)) + def test_sample_text_length(self): + corpus = self.DummyTextCorpus() + sample1 = list(corpus.sample_texts(1, length=1)) + self.assertEqual(sample1[0], ["document0"]) + + sample2 = list(corpus.sample_texts(2, length=2)) + self.assertEqual(sample2[0], ["document0"]) + self.assertEqual(sample2[1], ["document1"]) + + def test_sample_text_seed(self): + corpus = self.DummyTextCorpus() + sample1 = list(corpus.sample_texts(5, seed=42)) + sample2 = list(corpus.sample_texts(5, seed=42)) + self.assertEqual(sample1, sample2) + if __name__ == '__main__': logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)