diff --git a/gensim/test/test_corpora.py b/gensim/test/test_corpora.py index 3e47531a2c..c948080ef3 100644 --- a/gensim/test/test_corpora.py +++ b/gensim/test/test_corpora.py @@ -1,954 +1,988 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright (C) 2010 Radim Rehurek -# Licensed under the GNU LGPL v2.1 - https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html - -""" -Automated tests for checking corpus I/O formats (the corpora package). -""" - -from __future__ import unicode_literals - -import codecs -import itertools -import logging -import os -import os.path -import tempfile -import unittest - -import numpy as np - -from gensim.corpora import (bleicorpus, mmcorpus, lowcorpus, svmlightcorpus, - ucicorpus, malletcorpus, textcorpus, indexedcorpus, wikicorpus) -from gensim.interfaces import TransformedCorpus -from gensim.utils import to_unicode -from gensim.test.utils import datapath, get_tmpfile, common_corpus - - -GITHUB_ACTIONS_WINDOWS = os.environ.get('RUNNER_OS') == 'Windows' - - -class DummyTransformer: - def __getitem__(self, bow): - if len(next(iter(bow))) == 2: - # single bag of words - transformed = [(termid, count + 1) for termid, count in bow] - else: - # sliced corpus - transformed = [[(termid, count + 1) for termid, count in doc] for doc in bow] - return transformed - - -class CorpusTestCase(unittest.TestCase): - TEST_CORPUS = [[(1, 1.0)], [], [(0, 0.5), (2, 1.0)], []] - - def setUp(self): - self.corpus_class = None - self.file_extension = None - - def run(self, result=None): - if type(self) is not CorpusTestCase: - super(CorpusTestCase, self).run(result) - - def tearDown(self): - # remove all temporary test files - fname = get_tmpfile('gensim_corpus.tst') - extensions = ['', '', '.bz2', '.gz', '.index', '.vocab'] - for ext in itertools.permutations(extensions, 2): - try: - os.remove(fname + ext[0] + ext[1]) - except OSError: - pass - - @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') - def test_load(self): - fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) - corpus = self.corpus_class(fname) - - docs = list(corpus) - # the deerwester corpus always has nine documents - self.assertEqual(len(docs), 9) - - @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') - def test_len(self): - fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) - corpus = self.corpus_class(fname) - - # make sure corpus.index works, too - corpus = self.corpus_class(fname) - self.assertEqual(len(corpus), 9) - - # for subclasses of IndexedCorpus, we need to nuke this so we don't - # test length on the index, but just testcorpus contents - if hasattr(corpus, 'index'): - corpus.index = None - - self.assertEqual(len(corpus), 9) - - @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') - def test_empty_input(self): - tmpf = get_tmpfile('gensim_corpus.tst') - with open(tmpf, 'w') as f: - f.write('') - - with open(tmpf + '.vocab', 'w') as f: - f.write('') - - corpus = self.corpus_class(tmpf) - self.assertEqual(len(corpus), 0) - - docs = list(corpus) - self.assertEqual(len(docs), 0) - - @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') - def test_save(self): - corpus = self.TEST_CORPUS - tmpf = get_tmpfile('gensim_corpus.tst') - - # make sure the corpus can be saved - self.corpus_class.save_corpus(tmpf, corpus) - - # and loaded back, resulting in exactly the same corpus - corpus2 = list(self.corpus_class(tmpf)) - self.assertEqual(corpus, corpus2) - - @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') - def test_serialize(self): - corpus = self.TEST_CORPUS - tmpf = get_tmpfile('gensim_corpus.tst') - - # make sure the corpus can be saved - self.corpus_class.serialize(tmpf, corpus) - - # and loaded back, resulting in exactly the same corpus - corpus2 = self.corpus_class(tmpf) - self.assertEqual(corpus, list(corpus2)) - - # make sure the indexing corpus[i] works - for i in range(len(corpus)): - self.assertEqual(corpus[i], corpus2[i]) - - # make sure that subclasses of IndexedCorpus support fancy indexing - # after deserialisation - if isinstance(corpus, indexedcorpus.IndexedCorpus): - idx = [1, 3, 5, 7] - self.assertEqual(corpus[idx], corpus2[idx]) - - @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') - def test_serialize_compressed(self): - corpus = self.TEST_CORPUS - tmpf = get_tmpfile('gensim_corpus.tst') - - for extension in ['.gz', '.bz2']: - fname = tmpf + extension - # make sure the corpus can be saved - self.corpus_class.serialize(fname, corpus) - - # and loaded back, resulting in exactly the same corpus - corpus2 = self.corpus_class(fname) - self.assertEqual(corpus, list(corpus2)) - - # make sure the indexing `corpus[i]` syntax works - for i in range(len(corpus)): - self.assertEqual(corpus[i], corpus2[i]) - - @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') - def test_switch_id2word(self): - fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) - corpus = self.corpus_class(fname) - if hasattr(corpus, 'id2word'): - firstdoc = next(iter(corpus)) - testdoc = set((to_unicode(corpus.id2word[x]), y) for x, y in firstdoc) - - self.assertEqual(testdoc, {('computer', 1), ('human', 1), ('interface', 1)}) - - d = corpus.id2word - d[0], d[1] = d[1], d[0] - corpus.id2word = d - - firstdoc2 = next(iter(corpus)) - testdoc2 = set((to_unicode(corpus.id2word[x]), y) for x, y in firstdoc2) - self.assertEqual(testdoc2, {('computer', 1), ('human', 1), ('interface', 1)}) - - @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') - def test_indexing(self): - fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) - corpus = self.corpus_class(fname) - docs = list(corpus) - - for idx, doc in enumerate(docs): - self.assertEqual(doc, corpus[idx]) - self.assertEqual(doc, corpus[np.int64(idx)]) - - self.assertEqual(docs, list(corpus[:])) - self.assertEqual(docs[0:], list(corpus[0:])) - self.assertEqual(docs[0:-1], list(corpus[0:-1])) - self.assertEqual(docs[2:4], list(corpus[2:4])) - self.assertEqual(docs[::2], list(corpus[::2])) - self.assertEqual(docs[::-1], list(corpus[::-1])) - - # make sure sliced corpora can be iterated over multiple times - c = corpus[:] - self.assertEqual(docs, list(c)) - self.assertEqual(docs, list(c)) - self.assertEqual(len(docs), len(corpus)) - self.assertEqual(len(docs), len(corpus[:])) - self.assertEqual(len(docs[::2]), len(corpus[::2])) - - def _get_slice(corpus, slice_): - # assertRaises for python 2.6 takes a callable - return corpus[slice_] - - # make sure proper input validation for sliced corpora is done - self.assertRaises(ValueError, _get_slice, corpus, {1}) - self.assertRaises(ValueError, _get_slice, corpus, 1.0) - - # check sliced corpora that use fancy indexing - c = corpus[[1, 3, 4]] - self.assertEqual([d for i, d in enumerate(docs) if i in [1, 3, 4]], list(c)) - self.assertEqual([d for i, d in enumerate(docs) if i in [1, 3, 4]], list(c)) - self.assertEqual(len(corpus[[0, 1, -1]]), 3) - self.assertEqual(len(corpus[np.asarray([0, 1, -1])]), 3) - - # check that TransformedCorpus supports indexing when the underlying - # corpus does, and throws an error otherwise - corpus_ = TransformedCorpus(DummyTransformer(), corpus) - if hasattr(corpus, 'index') and corpus.index is not None: - self.assertEqual(corpus_[0][0][1], docs[0][0][1] + 1) - self.assertRaises(ValueError, _get_slice, corpus_, {1}) - transformed_docs = [val + 1 for i, d in enumerate(docs) for _, val in d if i in [1, 3, 4]] - self.assertEqual(transformed_docs, list(v for doc in corpus_[[1, 3, 4]] for _, v in doc)) - self.assertEqual(3, len(corpus_[[1, 3, 4]])) - else: - self.assertRaises(RuntimeError, _get_slice, corpus_, [1, 3, 4]) - self.assertRaises(RuntimeError, _get_slice, corpus_, {1}) - self.assertRaises(RuntimeError, _get_slice, corpus_, 1.0) - - -class TestMmCorpusWithIndex(CorpusTestCase): - def setUp(self): - self.corpus_class = mmcorpus.MmCorpus - self.corpus = self.corpus_class(datapath('test_mmcorpus_with_index.mm')) - self.file_extension = '.mm' - - def test_serialize_compressed(self): - # MmCorpus needs file write with seek => doesn't support compressed output (only input) - pass - - def test_closed_file_object(self): - file_obj = open(datapath('testcorpus.mm')) - f = file_obj.closed - mmcorpus.MmCorpus(file_obj) - s = file_obj.closed - self.assertEqual(f, 0) - self.assertEqual(s, 0) - - @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') - def test_load(self): - self.assertEqual(self.corpus.num_docs, 9) - self.assertEqual(self.corpus.num_terms, 12) - self.assertEqual(self.corpus.num_nnz, 28) - - # confirm we can iterate and that document values match expected for first three docs - it = iter(self.corpus) - self.assertEqual(next(it), [(0, 1.0), (1, 1.0), (2, 1.0)]) - self.assertEqual(next(it), [(0, 1.0), (3, 1.0), (4, 1.0), (5, 1.0), (6, 1.0), (7, 1.0)]) - self.assertEqual(next(it), [(2, 1.0), (5, 1.0), (7, 1.0), (8, 1.0)]) - - # confirm that accessing document by index works - self.assertEqual(self.corpus[3], [(1, 1.0), (5, 2.0), (8, 1.0)]) - self.assertEqual(tuple(self.corpus.index), (97, 121, 169, 201, 225, 249, 258, 276, 303)) - - -class TestMmCorpusNoIndex(CorpusTestCase): - def setUp(self): - self.corpus_class = mmcorpus.MmCorpus - self.corpus = self.corpus_class(datapath('test_mmcorpus_no_index.mm')) - self.file_extension = '.mm' - - def test_serialize_compressed(self): - # MmCorpus needs file write with seek => doesn't support compressed output (only input) - pass - - def test_load(self): - self.assertEqual(self.corpus.num_docs, 9) - self.assertEqual(self.corpus.num_terms, 12) - self.assertEqual(self.corpus.num_nnz, 28) - - # confirm we can iterate and that document values match expected for first three docs - it = iter(self.corpus) - self.assertEqual(next(it), [(0, 1.0), (1, 1.0), (2, 1.0)]) - self.assertEqual(next(it), []) - self.assertEqual(next(it), [(2, 0.42371910849), (5, 0.6625174), (7, 1.0), (8, 1.0)]) - - # confirm that accessing document by index fails - self.assertRaises(RuntimeError, lambda: self.corpus[3]) - - -class TestMmCorpusNoIndexGzip(CorpusTestCase): - def setUp(self): - self.corpus_class = mmcorpus.MmCorpus - self.corpus = self.corpus_class(datapath('test_mmcorpus_no_index.mm.gz')) - self.file_extension = '.mm' - - def test_serialize_compressed(self): - # MmCorpus needs file write with seek => doesn't support compressed output (only input) - pass - - def test_load(self): - self.assertEqual(self.corpus.num_docs, 9) - self.assertEqual(self.corpus.num_terms, 12) - self.assertEqual(self.corpus.num_nnz, 28) - - # confirm we can iterate and that document values match expected for first three docs - it = iter(self.corpus) - self.assertEqual(next(it), [(0, 1.0), (1, 1.0), (2, 1.0)]) - self.assertEqual(next(it), []) - self.assertEqual(next(it), [(2, 0.42371910849), (5, 0.6625174), (7, 1.0), (8, 1.0)]) - - # confirm that accessing document by index fails - self.assertRaises(RuntimeError, lambda: self.corpus[3]) - - -class TestMmCorpusNoIndexBzip(CorpusTestCase): - def setUp(self): - self.corpus_class = mmcorpus.MmCorpus - self.corpus = self.corpus_class(datapath('test_mmcorpus_no_index.mm.bz2')) - self.file_extension = '.mm' - - def test_serialize_compressed(self): - # MmCorpus needs file write with seek => doesn't support compressed output (only input) - pass - - def test_load(self): - self.assertEqual(self.corpus.num_docs, 9) - self.assertEqual(self.corpus.num_terms, 12) - self.assertEqual(self.corpus.num_nnz, 28) - - # confirm we can iterate and that document values match expected for first three docs - it = iter(self.corpus) - self.assertEqual(next(it), [(0, 1.0), (1, 1.0), (2, 1.0)]) - self.assertEqual(next(it), []) - self.assertEqual(next(it), [(2, 0.42371910849), (5, 0.6625174), (7, 1.0), (8, 1.0)]) - - # confirm that accessing document by index fails - self.assertRaises(RuntimeError, lambda: self.corpus[3]) - - -class TestMmCorpusCorrupt(CorpusTestCase): - def setUp(self): - self.corpus_class = mmcorpus.MmCorpus - self.corpus = self.corpus_class(datapath('test_mmcorpus_corrupt.mm')) - self.file_extension = '.mm' - - def test_serialize_compressed(self): - # MmCorpus needs file write with seek => doesn't support compressed output (only input) - pass - - def test_load(self): - self.assertRaises(ValueError, lambda: [doc for doc in self.corpus]) - - -class TestMmCorpusOverflow(CorpusTestCase): - """ - Test to make sure cython mmreader doesn't overflow on large number of docs or terms - - """ - def setUp(self): - self.corpus_class = mmcorpus.MmCorpus - self.corpus = self.corpus_class(datapath('test_mmcorpus_overflow.mm')) - self.file_extension = '.mm' - - def test_serialize_compressed(self): - # MmCorpus needs file write with seek => doesn't support compressed output (only input) - pass - - def test_load(self): - self.assertEqual(self.corpus.num_docs, 44270060) - self.assertEqual(self.corpus.num_terms, 500) - self.assertEqual(self.corpus.num_nnz, 22134988630) - - # confirm we can iterate and that document values match expected for first three docs - it = iter(self.corpus) - self.assertEqual(next(it)[:3], [(0, 0.3913027376444812), - (1, -0.07658791716226626), - (2, -0.020870794080588395)]) - self.assertEqual(next(it), []) - self.assertEqual(next(it), []) - - # confirm count of terms - count = 0 - for doc in self.corpus: - for term in doc: - count += 1 - - self.assertEqual(count, 12) - - # confirm that accessing document by index fails - self.assertRaises(RuntimeError, lambda: self.corpus[3]) - - -class TestSvmLightCorpus(CorpusTestCase): - def setUp(self): - self.corpus_class = svmlightcorpus.SvmLightCorpus - self.file_extension = '.svmlight' - - def test_serialization(self): - path = get_tmpfile("svml.corpus") - labels = [1] * len(common_corpus) - second_corpus = [(0, 1.0), (3, 1.0), (4, 1.0), (5, 1.0), (6, 1.0), (7, 1.0)] - self.corpus_class.serialize(path, common_corpus, labels=labels) - serialized_corpus = self.corpus_class(path) - self.assertEqual(serialized_corpus[1], second_corpus) - self.corpus_class.serialize(path, common_corpus, labels=np.array(labels)) - serialized_corpus = self.corpus_class(path) - self.assertEqual(serialized_corpus[1], second_corpus) - - -class TestBleiCorpus(CorpusTestCase): - def setUp(self): - self.corpus_class = bleicorpus.BleiCorpus - self.file_extension = '.blei' - - def test_save_format_for_dtm(self): - corpus = [[(1, 1.0)], [], [(0, 5.0), (2, 1.0)], []] - test_file = get_tmpfile('gensim_corpus.tst') - self.corpus_class.save_corpus(test_file, corpus) - with open(test_file) as f: - for line in f: - # unique_word_count index1:count1 index2:count2 ... indexn:count - tokens = line.split() - words_len = int(tokens[0]) - if words_len > 0: - tokens = tokens[1:] - else: - tokens = [] - self.assertEqual(words_len, len(tokens)) - for token in tokens: - word, count = token.split(':') - self.assertEqual(count, str(int(count))) - - -class TestLowCorpus(CorpusTestCase): - TEST_CORPUS = [[(1, 1)], [], [(0, 2), (2, 1)], []] - CORPUS_LINE = 'mom wash window window was washed' - - def setUp(self): - self.corpus_class = lowcorpus.LowCorpus - self.file_extension = '.low' - - def test_line2doc(self): - fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) - id2word = {1: 'mom', 2: 'window'} - - corpus = self.corpus_class(fname, id2word=id2word) - - # should return all words in doc - corpus.use_wordids = False - self.assertEqual( - sorted(corpus.line2doc(self.CORPUS_LINE)), - [('mom', 1), ('was', 1), ('wash', 1), ('washed', 1), ('window', 2)]) - - # should return words in word2id - corpus.use_wordids = True - self.assertEqual( - sorted(corpus.line2doc(self.CORPUS_LINE)), - [(1, 1), (2, 2)]) - - -class TestUciCorpus(CorpusTestCase): - TEST_CORPUS = [[(1, 1)], [], [(0, 2), (2, 1)], []] - - def setUp(self): - self.corpus_class = ucicorpus.UciCorpus - self.file_extension = '.uci' - - def test_serialize_compressed(self): - # UciCorpus needs file write with seek => doesn't support compressed output (only input) - pass - - -class TestMalletCorpus(TestLowCorpus): - TEST_CORPUS = [[(1, 1)], [], [(0, 2), (2, 1)], []] - CORPUS_LINE = '#3 lang mom wash window window was washed' - - def setUp(self): - self.corpus_class = malletcorpus.MalletCorpus - self.file_extension = '.mallet' - - def test_load_with_metadata(self): - fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) - corpus = self.corpus_class(fname) - corpus.metadata = True - self.assertEqual(len(corpus), 9) - - docs = list(corpus) - self.assertEqual(len(docs), 9) - - for i, docmeta in enumerate(docs): - doc, metadata = docmeta - self.assertEqual(metadata[0], str(i + 1)) - self.assertEqual(metadata[1], 'en') - - def test_line2doc(self): - # case with metadata=False (by default) - super(TestMalletCorpus, self).test_line2doc() - - # case with metadata=True - fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) - id2word = {1: 'mom', 2: 'window'} - - corpus = self.corpus_class(fname, id2word=id2word, metadata=True) - - # should return all words in doc - corpus.use_wordids = False - doc, (docid, doclang) = corpus.line2doc(self.CORPUS_LINE) - self.assertEqual(docid, '#3') - self.assertEqual(doclang, 'lang') - self.assertEqual( - sorted(doc), - [('mom', 1), ('was', 1), ('wash', 1), ('washed', 1), ('window', 2)]) - - # should return words in word2id - corpus.use_wordids = True - doc, (docid, doclang) = corpus.line2doc(self.CORPUS_LINE) - - self.assertEqual(docid, '#3') - self.assertEqual(doclang, 'lang') - self.assertEqual( - sorted(doc), - [(1, 1), (2, 2)]) - - -class TestTextCorpus(CorpusTestCase): - - def setUp(self): - self.corpus_class = textcorpus.TextCorpus - self.file_extension = '.txt' - - def test_load_with_metadata(self): - fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) - corpus = self.corpus_class(fname) - corpus.metadata = True - self.assertEqual(len(corpus), 9) - - docs = list(corpus) - self.assertEqual(len(docs), 9) - - for i, docmeta in enumerate(docs): - doc, metadata = docmeta - self.assertEqual(metadata[0], i) - - def test_default_preprocessing(self): - lines = [ - "Šéf chomutovských komunistů dostal poštou bílý prášek", - "this is a test for stopwords", - "zf tooth spaces " - ] - expected = [ - ['Sef', 'chomutovskych', 'komunistu', 'dostal', 'postou', 'bily', 'prasek'], - ['test', 'stopwords'], - ['tooth', 'spaces'] - ] - - corpus = self.corpus_from_lines(lines) - texts = list(corpus.get_texts()) - self.assertEqual(expected, texts) - - def corpus_from_lines(self, lines): - fpath = tempfile.mktemp() - with codecs.open(fpath, 'w', encoding='utf8') as f: - f.write('\n'.join(lines)) - - return self.corpus_class(fpath) - - def test_sample_text(self): - lines = ["document%d" % i for i in range(10)] - corpus = self.corpus_from_lines(lines) - corpus.tokenizer = lambda text: text.split() - docs = [doc for doc in corpus.get_texts()] - - sample1 = list(corpus.sample_texts(1)) - self.assertEqual(len(sample1), 1) - self.assertIn(sample1[0], docs) - - sample2 = list(corpus.sample_texts(len(lines))) - self.assertEqual(len(sample2), len(corpus)) - for i in range(len(corpus)): - self.assertEqual(sample2[i], ["document%s" % i]) - - with self.assertRaises(ValueError): - list(corpus.sample_texts(len(corpus) + 1)) - - with self.assertRaises(ValueError): - list(corpus.sample_texts(-1)) - - def test_sample_text_length(self): - lines = ["document%d" % i for i in range(10)] - corpus = self.corpus_from_lines(lines) - corpus.tokenizer = lambda text: text.split() - - sample1 = list(corpus.sample_texts(1, length=1)) - self.assertEqual(sample1[0], ["document0"]) - - sample2 = list(corpus.sample_texts(2, length=2)) - self.assertEqual(sample2[0], ["document0"]) - self.assertEqual(sample2[1], ["document1"]) - - def test_sample_text_seed(self): - lines = ["document%d" % i for i in range(10)] - corpus = self.corpus_from_lines(lines) - - sample1 = list(corpus.sample_texts(5, seed=42)) - sample2 = list(corpus.sample_texts(5, seed=42)) - self.assertEqual(sample1, sample2) - - def test_save(self): - pass - - def test_serialize(self): - pass - - def test_serialize_compressed(self): - pass - - def test_indexing(self): - pass - - -# Needed for the test_custom_tokenizer is the TestWikiCorpus class. -# Cannot be nested due to serializing. -def custom_tokenizer(content, token_min_len=2, token_max_len=15, lower=True): - return [ - to_unicode(token.lower()) if lower else to_unicode(token) for token in content.split() - if token_min_len <= len(token) <= token_max_len and not token.startswith('_') - ] - - -class TestWikiCorpus(TestTextCorpus): - def setUp(self): - self.corpus_class = wikicorpus.WikiCorpus - self.file_extension = '.xml.bz2' - self.fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) - self.enwiki = datapath('enwiki-latest-pages-articles1.xml-p000000010p000030302-shortened.bz2') - - def test_default_preprocessing(self): - expected = ['computer', 'human', 'interface'] - corpus = self.corpus_class(self.fname, article_min_tokens=0) - first_text = next(corpus.get_texts()) - self.assertEqual(expected, first_text) - - def test_len(self): - # When there is no min_token limit all 9 articles must be registered. - corpus = self.corpus_class(self.fname, article_min_tokens=0) - all_articles = corpus.get_texts() - assert (len(list(all_articles)) == 9) - - # With a huge min_token limit, all articles should be filtered out. - corpus = self.corpus_class(self.fname, article_min_tokens=100000) - all_articles = corpus.get_texts() - assert (len(list(all_articles)) == 0) - - def test_load_with_metadata(self): - corpus = self.corpus_class(self.fname, article_min_tokens=0) - corpus.metadata = True - self.assertEqual(len(corpus), 9) - - docs = list(corpus) - self.assertEqual(len(docs), 9) - - for i, docmeta in enumerate(docs): - doc, metadata = docmeta - article_no = i + 1 # Counting IDs from 1 - self.assertEqual(metadata[0], str(article_no)) - self.assertEqual(metadata[1], 'Article%d' % article_no) - - def test_load(self): - corpus = self.corpus_class(self.fname, article_min_tokens=0) - - docs = list(corpus) - # the deerwester corpus always has nine documents - self.assertEqual(len(docs), 9) - - def test_first_element(self): - """ - First two articles in this sample are - 1) anarchism - 2) autism - """ - corpus = self.corpus_class(self.enwiki, processes=1) - - texts = corpus.get_texts() - self.assertTrue(u'anarchism' in next(texts)) - self.assertTrue(u'autism' in next(texts)) - - def test_unicode_element(self): - """ - First unicode article in this sample is - 1) папа - """ - bgwiki = datapath('bgwiki-latest-pages-articles-shortened.xml.bz2') - corpus = self.corpus_class(bgwiki) - texts = corpus.get_texts() - self.assertTrue(u'папа' in next(texts)) - - def test_custom_tokenizer(self): - """ - define a custom tokenizer function and use it - """ - wc = self.corpus_class(self.enwiki, processes=1, tokenizer_func=custom_tokenizer, - token_max_len=16, token_min_len=1, lower=False) - row = wc.get_texts() - list_tokens = next(row) - self.assertTrue(u'Anarchism' in list_tokens) - self.assertTrue(u'collectivization' in list_tokens) - self.assertTrue(u'a' in list_tokens) - self.assertTrue(u'i.e.' in list_tokens) - - def test_lower_case_set_true(self): - """ - Set the parameter lower to True and check that upper case 'Anarchism' token doesnt exist - """ - corpus = self.corpus_class(self.enwiki, processes=1, lower=True) - row = corpus.get_texts() - list_tokens = next(row) - self.assertTrue(u'Anarchism' not in list_tokens) - self.assertTrue(u'anarchism' in list_tokens) - - def test_lower_case_set_false(self): - """ - Set the parameter lower to False and check that upper case Anarchism' token exists - """ - corpus = self.corpus_class(self.enwiki, processes=1, lower=False) - row = corpus.get_texts() - list_tokens = next(row) - self.assertTrue(u'Anarchism' in list_tokens) - self.assertTrue(u'anarchism' in list_tokens) - - def test_min_token_len_not_set(self): - """ - Don't set the parameter token_min_len and check that 'a' as a token doesn't exist - Default token_min_len=2 - """ - corpus = self.corpus_class(self.enwiki, processes=1) - self.assertTrue(u'a' not in next(corpus.get_texts())) - - def test_min_token_len_set(self): - """ - Set the parameter token_min_len to 1 and check that 'a' as a token exists - """ - corpus = self.corpus_class(self.enwiki, processes=1, token_min_len=1) - self.assertTrue(u'a' in next(corpus.get_texts())) - - def test_max_token_len_not_set(self): - """ - Don't set the parameter token_max_len and check that 'collectivisation' as a token doesn't exist - Default token_max_len=15 - """ - corpus = self.corpus_class(self.enwiki, processes=1) - self.assertTrue(u'collectivization' not in next(corpus.get_texts())) - - def test_max_token_len_set(self): - """ - Set the parameter token_max_len to 16 and check that 'collectivisation' as a token exists - """ - corpus = self.corpus_class(self.enwiki, processes=1, token_max_len=16) - self.assertTrue(u'collectivization' in next(corpus.get_texts())) - - def test_removed_table_markup(self): - """ - Check if all the table markup has been removed. - """ - enwiki_file = datapath('enwiki-table-markup.xml.bz2') - corpus = self.corpus_class(enwiki_file) - texts = corpus.get_texts() - table_markup = ["style", "class", "border", "cellspacing", "cellpadding", "colspan", "rowspan"] - for text in texts: - for word in table_markup: - self.assertTrue(word not in text) - - def test_get_stream(self): - wiki = self.corpus_class(self.enwiki) - sample_text_wiki = next(wiki.getstream()).decode()[1:14] - self.assertEqual(sample_text_wiki, "mediawiki xml") - - # #TODO: sporadic failure to be investigated - # def test_get_texts_returns_generator_of_lists(self): - # corpus = self.corpus_class(self.enwiki) - # l = corpus.get_texts() - # self.assertEqual(type(l), types.GeneratorType) - # first = next(l) - # self.assertEqual(type(first), list) - # self.assertTrue(isinstance(first[0], bytes) or isinstance(first[0], str)) - - def test_sample_text(self): - # Cannot instantiate WikiCorpus from lines - pass - - def test_sample_text_length(self): - # Cannot instantiate WikiCorpus from lines - pass - - def test_sample_text_seed(self): - # Cannot instantiate WikiCorpus from lines - pass - - def test_empty_input(self): - # An empty file is not legit XML - pass - - def test_custom_filterfunction(self): - def reject_all(elem, *args, **kwargs): - return False - corpus = self.corpus_class(self.enwiki, filter_articles=reject_all) - texts = corpus.get_texts() - self.assertFalse(any(texts)) - - def keep_some(elem, title, *args, **kwargs): - return title[0] == 'C' - corpus = self.corpus_class(self.enwiki, filter_articles=reject_all) - corpus.metadata = True - texts = corpus.get_texts() - for text, (pageid, title) in texts: - self.assertEquals(title[0], 'C') - - -class TestTextDirectoryCorpus(unittest.TestCase): - - def write_one_level(self, *args): - if not args: - args = ('doc1', 'doc2') - dirpath = tempfile.mkdtemp() - self.write_docs_to_directory(dirpath, *args) - return dirpath - - def write_docs_to_directory(self, dirpath, *args): - for doc_num, name in enumerate(args): - with open(os.path.join(dirpath, name), 'w') as f: - f.write('document %d content' % doc_num) - - def test_one_level_directory(self): - dirpath = self.write_one_level() - - corpus = textcorpus.TextDirectoryCorpus(dirpath) - self.assertEqual(len(corpus), 2) - docs = list(corpus) - self.assertEqual(len(docs), 2) - - def write_two_levels(self): - dirpath = self.write_one_level() - next_level = os.path.join(dirpath, 'level_two') - os.mkdir(next_level) - self.write_docs_to_directory(next_level, 'doc1', 'doc2') - return dirpath, next_level - - def test_two_level_directory(self): - dirpath, next_level = self.write_two_levels() - - corpus = textcorpus.TextDirectoryCorpus(dirpath) - self.assertEqual(len(corpus), 4) - docs = list(corpus) - self.assertEqual(len(docs), 4) - - corpus = textcorpus.TextDirectoryCorpus(dirpath, min_depth=1) - self.assertEqual(len(corpus), 2) - docs = list(corpus) - self.assertEqual(len(docs), 2) - - corpus = textcorpus.TextDirectoryCorpus(dirpath, max_depth=0) - self.assertEqual(len(corpus), 2) - docs = list(corpus) - self.assertEqual(len(docs), 2) - - def test_filename_filtering(self): - dirpath = self.write_one_level('test1.log', 'test1.txt', 'test2.log', 'other1.log') - corpus = textcorpus.TextDirectoryCorpus(dirpath, pattern=r"test.*\.log") - filenames = list(corpus.iter_filepaths()) - expected = [os.path.join(dirpath, name) for name in ('test1.log', 'test2.log')] - self.assertEqual(sorted(expected), sorted(filenames)) - - corpus.pattern = ".*.txt" - filenames = list(corpus.iter_filepaths()) - expected = [os.path.join(dirpath, 'test1.txt')] - self.assertEqual(expected, filenames) - - corpus.pattern = None - corpus.exclude_pattern = ".*.log" - filenames = list(corpus.iter_filepaths()) - self.assertEqual(expected, filenames) - - def test_lines_are_documents(self): - dirpath = tempfile.mkdtemp() - lines = ['doc%d text' % i for i in range(5)] - fpath = os.path.join(dirpath, 'test_file.txt') - with open(fpath, 'w') as f: - f.write('\n'.join(lines)) - - corpus = textcorpus.TextDirectoryCorpus(dirpath, lines_are_documents=True) - docs = [doc for doc in corpus.getstream()] - self.assertEqual(len(lines), corpus.length) # should have cached - self.assertEqual(lines, docs) - - corpus.lines_are_documents = False - docs = [doc for doc in corpus.getstream()] - self.assertEqual(1, corpus.length) - self.assertEqual('\n'.join(lines), docs[0]) - - def test_non_trivial_structure(self): - """Test with non-trivial directory structure, shown below: - . - ├── 0.txt - ├── a_folder - │ └── 1.txt - └── b_folder - ├── 2.txt - ├── 3.txt - └── c_folder - └── 4.txt - """ - dirpath = tempfile.mkdtemp() - self.write_docs_to_directory(dirpath, '0.txt') - - a_folder = os.path.join(dirpath, 'a_folder') - os.mkdir(a_folder) - self.write_docs_to_directory(a_folder, '1.txt') - - b_folder = os.path.join(dirpath, 'b_folder') - os.mkdir(b_folder) - self.write_docs_to_directory(b_folder, '2.txt', '3.txt') - - c_folder = os.path.join(b_folder, 'c_folder') - os.mkdir(c_folder) - self.write_docs_to_directory(c_folder, '4.txt') - - corpus = textcorpus.TextDirectoryCorpus(dirpath) - filenames = list(corpus.iter_filepaths()) - base_names = sorted(name[len(dirpath) + 1:] for name in filenames) - expected = sorted([ - '0.txt', - 'a_folder/1.txt', - 'b_folder/2.txt', - 'b_folder/3.txt', - 'b_folder/c_folder/4.txt' - ]) - expected = [os.path.normpath(path) for path in expected] - self.assertEqual(expected, base_names) - - corpus.max_depth = 1 - self.assertEqual(expected[:-1], base_names[:-1]) - - corpus.min_depth = 1 - self.assertEqual(expected[2:-1], base_names[2:-1]) - - corpus.max_depth = 0 - self.assertEqual(expected[2:], base_names[2:]) - - corpus.pattern = "4.*" - self.assertEqual(expected[-1], base_names[-1]) - - -if __name__ == '__main__': - logging.basicConfig(level=logging.DEBUG) - unittest.main() +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (C) 2010 Radim Rehurek +# Licensed under the GNU LGPL v2.1 - https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html + +""" +Automated tests for checking corpus I/O formats (the corpora package). +""" + +from __future__ import unicode_literals + +import codecs +import itertools +import logging +import os +import os.path +import tempfile +import unittest + +import numpy as np + +from gensim.corpora import (bleicorpus, mmcorpus, lowcorpus, svmlightcorpus, + ucicorpus, malletcorpus, textcorpus, indexedcorpus, wikicorpus) +from gensim.interfaces import TransformedCorpus +from gensim.utils import to_unicode +from gensim.test.utils import datapath, get_tmpfile, common_corpus + + +GITHUB_ACTIONS_WINDOWS = os.environ.get('RUNNER_OS') == 'Windows' + + +class DummyTransformer: + def __getitem__(self, bow): + if len(next(iter(bow))) == 2: + # single bag of words + transformed = [(termid, count + 1) for termid, count in bow] + else: + # sliced corpus + transformed = [[(termid, count + 1) for termid, count in doc] for doc in bow] + return transformed + + +class CorpusTestCase(unittest.TestCase): + TEST_CORPUS = [[(1, 1.0)], [], [(0, 0.5), (2, 1.0)], []] + + def setUp(self): + self.corpus_class = None + self.file_extension = None + + def run(self, result=None): + if type(self) is not CorpusTestCase: + super(CorpusTestCase, self).run(result) + + def tearDown(self): + # remove all temporary test files + fname = get_tmpfile('gensim_corpus.tst') + extensions = ['', '', '.bz2', '.gz', '.index', '.vocab'] + for ext in itertools.permutations(extensions, 2): + try: + os.remove(fname + ext[0] + ext[1]) + except OSError: + pass + + @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') + def test_load(self): + fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) + corpus = self.corpus_class(fname) + + docs = list(corpus) + # the deerwester corpus always has nine documents + self.assertEqual(len(docs), 9) + + @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') + def test_len(self): + fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) + corpus = self.corpus_class(fname) + + # make sure corpus.index works, too + corpus = self.corpus_class(fname) + self.assertEqual(len(corpus), 9) + + # for subclasses of IndexedCorpus, we need to nuke this so we don't + # test length on the index, but just testcorpus contents + if hasattr(corpus, 'index'): + corpus.index = None + + self.assertEqual(len(corpus), 9) + + @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') + def test_empty_input(self): + tmpf = get_tmpfile('gensim_corpus.tst') + with open(tmpf, 'w') as f: + f.write('') + + with open(tmpf + '.vocab', 'w') as f: + f.write('') + + corpus = self.corpus_class(tmpf) + self.assertEqual(len(corpus), 0) + + docs = list(corpus) + self.assertEqual(len(docs), 0) + + @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') + def test_save(self): + corpus = self.TEST_CORPUS + tmpf = get_tmpfile('gensim_corpus.tst') + + # make sure the corpus can be saved + self.corpus_class.save_corpus(tmpf, corpus) + + # and loaded back, resulting in exactly the same corpus + corpus2 = list(self.corpus_class(tmpf)) + self.assertEqual(corpus, corpus2) + + @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') + def test_serialize(self): + corpus = self.TEST_CORPUS + tmpf = get_tmpfile('gensim_corpus.tst') + + # make sure the corpus can be saved + self.corpus_class.serialize(tmpf, corpus) + + # and loaded back, resulting in exactly the same corpus + corpus2 = self.corpus_class(tmpf) + self.assertEqual(corpus, list(corpus2)) + + # make sure the indexing corpus[i] works + for i in range(len(corpus)): + self.assertEqual(corpus[i], corpus2[i]) + + # make sure that subclasses of IndexedCorpus support fancy indexing + # after deserialisation + if isinstance(corpus, indexedcorpus.IndexedCorpus): + idx = [1, 3, 5, 7] + self.assertEqual(corpus[idx], corpus2[idx]) + + @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') + def test_serialize_compressed(self): + corpus = self.TEST_CORPUS + tmpf = get_tmpfile('gensim_corpus.tst') + + for extension in ['.gz', '.bz2']: + fname = tmpf + extension + # make sure the corpus can be saved + self.corpus_class.serialize(fname, corpus) + + # and loaded back, resulting in exactly the same corpus + corpus2 = self.corpus_class(fname) + self.assertEqual(corpus, list(corpus2)) + + # make sure the indexing `corpus[i]` syntax works + for i in range(len(corpus)): + self.assertEqual(corpus[i], corpus2[i]) + + @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') + def test_switch_id2word(self): + fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) + corpus = self.corpus_class(fname) + if hasattr(corpus, 'id2word'): + firstdoc = next(iter(corpus)) + testdoc = set((to_unicode(corpus.id2word[x]), y) for x, y in firstdoc) + + self.assertEqual(testdoc, {('computer', 1), ('human', 1), ('interface', 1)}) + + d = corpus.id2word + d[0], d[1] = d[1], d[0] + corpus.id2word = d + + firstdoc2 = next(iter(corpus)) + testdoc2 = set((to_unicode(corpus.id2word[x]), y) for x, y in firstdoc2) + self.assertEqual(testdoc2, {('computer', 1), ('human', 1), ('interface', 1)}) + + @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') + def test_indexing(self): + fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) + corpus = self.corpus_class(fname) + docs = list(corpus) + + for idx, doc in enumerate(docs): + self.assertEqual(doc, corpus[idx]) + self.assertEqual(doc, corpus[np.int64(idx)]) + + self.assertEqual(docs, list(corpus[:])) + self.assertEqual(docs[0:], list(corpus[0:])) + self.assertEqual(docs[0:-1], list(corpus[0:-1])) + self.assertEqual(docs[2:4], list(corpus[2:4])) + self.assertEqual(docs[::2], list(corpus[::2])) + self.assertEqual(docs[::-1], list(corpus[::-1])) + + # make sure sliced corpora can be iterated over multiple times + c = corpus[:] + self.assertEqual(docs, list(c)) + self.assertEqual(docs, list(c)) + self.assertEqual(len(docs), len(corpus)) + self.assertEqual(len(docs), len(corpus[:])) + self.assertEqual(len(docs[::2]), len(corpus[::2])) + + def _get_slice(corpus, slice_): + # assertRaises for python 2.6 takes a callable + return corpus[slice_] + + # make sure proper input validation for sliced corpora is done + self.assertRaises(ValueError, _get_slice, corpus, {1}) + self.assertRaises(ValueError, _get_slice, corpus, 1.0) + + # check sliced corpora that use fancy indexing + c = corpus[[1, 3, 4]] + self.assertEqual([d for i, d in enumerate(docs) if i in [1, 3, 4]], list(c)) + self.assertEqual([d for i, d in enumerate(docs) if i in [1, 3, 4]], list(c)) + self.assertEqual(len(corpus[[0, 1, -1]]), 3) + self.assertEqual(len(corpus[np.asarray([0, 1, -1])]), 3) + + # check that TransformedCorpus supports indexing when the underlying + # corpus does, and throws an error otherwise + corpus_ = TransformedCorpus(DummyTransformer(), corpus) + if hasattr(corpus, 'index') and corpus.index is not None: + self.assertEqual(corpus_[0][0][1], docs[0][0][1] + 1) + self.assertRaises(ValueError, _get_slice, corpus_, {1}) + transformed_docs = [val + 1 for i, d in enumerate(docs) for _, val in d if i in [1, 3, 4]] + self.assertEqual(transformed_docs, list(v for doc in corpus_[[1, 3, 4]] for _, v in doc)) + self.assertEqual(3, len(corpus_[[1, 3, 4]])) + else: + self.assertRaises(RuntimeError, _get_slice, corpus_, [1, 3, 4]) + self.assertRaises(RuntimeError, _get_slice, corpus_, {1}) + self.assertRaises(RuntimeError, _get_slice, corpus_, 1.0) + + +class TestMmCorpusWithIndex(CorpusTestCase): + def setUp(self): + self.corpus_class = mmcorpus.MmCorpus + self.corpus = self.corpus_class(datapath('test_mmcorpus_with_index.mm')) + self.file_extension = '.mm' + + def test_serialize_compressed(self): + # MmCorpus needs file write with seek => doesn't support compressed output (only input) + pass + + def test_closed_file_object(self): + file_obj = open(datapath('testcorpus.mm')) + f = file_obj.closed + mmcorpus.MmCorpus(file_obj) + s = file_obj.closed + self.assertEqual(f, 0) + self.assertEqual(s, 0) + + @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see ') + def test_load(self): + self.assertEqual(self.corpus.num_docs, 9) + self.assertEqual(self.corpus.num_terms, 12) + self.assertEqual(self.corpus.num_nnz, 28) + + # confirm we can iterate and that document values match expected for first three docs + it = iter(self.corpus) + self.assertEqual(next(it), [(0, 1.0), (1, 1.0), (2, 1.0)]) + self.assertEqual(next(it), [(0, 1.0), (3, 1.0), (4, 1.0), (5, 1.0), (6, 1.0), (7, 1.0)]) + self.assertEqual(next(it), [(2, 1.0), (5, 1.0), (7, 1.0), (8, 1.0)]) + + # confirm that accessing document by index works + self.assertEqual(self.corpus[3], [(1, 1.0), (5, 2.0), (8, 1.0)]) + self.assertEqual(tuple(self.corpus.index), (97, 121, 169, 201, 225, 249, 258, 276, 303)) + + +class TestMmCorpusNoIndex(CorpusTestCase): + def setUp(self): + self.corpus_class = mmcorpus.MmCorpus + self.corpus = self.corpus_class(datapath('test_mmcorpus_no_index.mm')) + self.file_extension = '.mm' + + def test_serialize_compressed(self): + # MmCorpus needs file write with seek => doesn't support compressed output (only input) + pass + + def test_load(self): + self.assertEqual(self.corpus.num_docs, 9) + self.assertEqual(self.corpus.num_terms, 12) + self.assertEqual(self.corpus.num_nnz, 28) + + # confirm we can iterate and that document values match expected for first three docs + it = iter(self.corpus) + self.assertEqual(next(it), [(0, 1.0), (1, 1.0), (2, 1.0)]) + self.assertEqual(next(it), []) + self.assertEqual(next(it), [(2, 0.42371910849), (5, 0.6625174), (7, 1.0), (8, 1.0)]) + + # confirm that accessing document by index fails + self.assertRaises(RuntimeError, lambda: self.corpus[3]) + + +class TestMmCorpusNoIndexGzip(CorpusTestCase): + def setUp(self): + self.corpus_class = mmcorpus.MmCorpus + self.corpus = self.corpus_class(datapath('test_mmcorpus_no_index.mm.gz')) + self.file_extension = '.mm' + + def test_serialize_compressed(self): + # MmCorpus needs file write with seek => doesn't support compressed output (only input) + pass + + def test_load(self): + self.assertEqual(self.corpus.num_docs, 9) + self.assertEqual(self.corpus.num_terms, 12) + self.assertEqual(self.corpus.num_nnz, 28) + + # confirm we can iterate and that document values match expected for first three docs + it = iter(self.corpus) + self.assertEqual(next(it), [(0, 1.0), (1, 1.0), (2, 1.0)]) + self.assertEqual(next(it), []) + self.assertEqual(next(it), [(2, 0.42371910849), (5, 0.6625174), (7, 1.0), (8, 1.0)]) + + # confirm that accessing document by index fails + self.assertRaises(RuntimeError, lambda: self.corpus[3]) + + +class TestMmCorpusNoIndexBzip(CorpusTestCase): + def setUp(self): + self.corpus_class = mmcorpus.MmCorpus + self.corpus = self.corpus_class(datapath('test_mmcorpus_no_index.mm.bz2')) + self.file_extension = '.mm' + + def test_serialize_compressed(self): + # MmCorpus needs file write with seek => doesn't support compressed output (only input) + pass + + def test_load(self): + self.assertEqual(self.corpus.num_docs, 9) + self.assertEqual(self.corpus.num_terms, 12) + self.assertEqual(self.corpus.num_nnz, 28) + + # confirm we can iterate and that document values match expected for first three docs + it = iter(self.corpus) + self.assertEqual(next(it), [(0, 1.0), (1, 1.0), (2, 1.0)]) + self.assertEqual(next(it), []) + self.assertEqual(next(it), [(2, 0.42371910849), (5, 0.6625174), (7, 1.0), (8, 1.0)]) + + # confirm that accessing document by index fails + self.assertRaises(RuntimeError, lambda: self.corpus[3]) + + +class TestMmCorpusCorrupt(CorpusTestCase): + def setUp(self): + self.corpus_class = mmcorpus.MmCorpus + self.corpus = self.corpus_class(datapath('test_mmcorpus_corrupt.mm')) + self.file_extension = '.mm' + + def test_serialize_compressed(self): + # MmCorpus needs file write with seek => doesn't support compressed output (only input) + pass + + def test_load(self): + self.assertRaises(ValueError, lambda: [doc for doc in self.corpus]) + + +class TestMmCorpusOverflow(CorpusTestCase): + """ + Test to make sure cython mmreader doesn't overflow on large number of docs or terms + + """ + def setUp(self): + self.corpus_class = mmcorpus.MmCorpus + self.corpus = self.corpus_class(datapath('test_mmcorpus_overflow.mm')) + self.file_extension = '.mm' + + def test_serialize_compressed(self): + # MmCorpus needs file write with seek => doesn't support compressed output (only input) + pass + + def test_load(self): + self.assertEqual(self.corpus.num_docs, 44270060) + self.assertEqual(self.corpus.num_terms, 500) + self.assertEqual(self.corpus.num_nnz, 22134988630) + + # confirm we can iterate and that document values match expected for first three docs + it = iter(self.corpus) + self.assertEqual(next(it)[:3], [(0, 0.3913027376444812), + (1, -0.07658791716226626), + (2, -0.020870794080588395)]) + self.assertEqual(next(it), []) + self.assertEqual(next(it), []) + + # confirm count of terms + count = 0 + for doc in self.corpus: + for term in doc: + count += 1 + + self.assertEqual(count, 12) + + # confirm that accessing document by index fails + self.assertRaises(RuntimeError, lambda: self.corpus[3]) + + +class TestSvmLightCorpus(CorpusTestCase): + def setUp(self): + self.corpus_class = svmlightcorpus.SvmLightCorpus + self.file_extension = '.svmlight' + + def test_serialization(self): + path = get_tmpfile("svml.corpus") + labels = [1] * len(common_corpus) + second_corpus = [(0, 1.0), (3, 1.0), (4, 1.0), (5, 1.0), (6, 1.0), (7, 1.0)] + self.corpus_class.serialize(path, common_corpus, labels=labels) + serialized_corpus = self.corpus_class(path) + self.assertEqual(serialized_corpus[1], second_corpus) + self.corpus_class.serialize(path, common_corpus, labels=np.array(labels)) + serialized_corpus = self.corpus_class(path) + self.assertEqual(serialized_corpus[1], second_corpus) + + +class TestBleiCorpus(CorpusTestCase): + def setUp(self): + self.corpus_class = bleicorpus.BleiCorpus + self.file_extension = '.blei' + + def test_save_format_for_dtm(self): + corpus = [[(1, 1.0)], [], [(0, 5.0), (2, 1.0)], []] + test_file = get_tmpfile('gensim_corpus.tst') + self.corpus_class.save_corpus(test_file, corpus) + with open(test_file) as f: + for line in f: + # unique_word_count index1:count1 index2:count2 ... indexn:count + tokens = line.split() + words_len = int(tokens[0]) + if words_len > 0: + tokens = tokens[1:] + else: + tokens = [] + self.assertEqual(words_len, len(tokens)) + for token in tokens: + word, count = token.split(':') + self.assertEqual(count, str(int(count))) + + +class TestLowCorpus(CorpusTestCase): + TEST_CORPUS = [[(1, 1)], [], [(0, 2), (2, 1)], []] + CORPUS_LINE = 'mom wash window window was washed' + + def setUp(self): + self.corpus_class = lowcorpus.LowCorpus + self.file_extension = '.low' + + def test_line2doc(self): + fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) + id2word = {1: 'mom', 2: 'window'} + + corpus = self.corpus_class(fname, id2word=id2word) + + # should return all words in doc + corpus.use_wordids = False + self.assertEqual( + sorted(corpus.line2doc(self.CORPUS_LINE)), + [('mom', 1), ('was', 1), ('wash', 1), ('washed', 1), ('window', 2)]) + + # should return words in word2id + corpus.use_wordids = True + self.assertEqual( + sorted(corpus.line2doc(self.CORPUS_LINE)), + [(1, 1), (2, 2)]) + + +class TestUciCorpus(CorpusTestCase): + TEST_CORPUS = [[(1, 1)], [], [(0, 2), (2, 1)], []] + + def setUp(self): + self.corpus_class = ucicorpus.UciCorpus + self.file_extension = '.uci' + + def test_serialize_compressed(self): + # UciCorpus needs file write with seek => doesn't support compressed output (only input) + pass + + +class TestMalletCorpus(TestLowCorpus): + TEST_CORPUS = [[(1, 1)], [], [(0, 2), (2, 1)], []] + CORPUS_LINE = '#3 lang mom wash window window was washed' + + def setUp(self): + self.corpus_class = malletcorpus.MalletCorpus + self.file_extension = '.mallet' + + def test_load_with_metadata(self): + fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) + corpus = self.corpus_class(fname) + corpus.metadata = True + self.assertEqual(len(corpus), 9) + + docs = list(corpus) + self.assertEqual(len(docs), 9) + + for i, docmeta in enumerate(docs): + doc, metadata = docmeta + self.assertEqual(metadata[0], str(i + 1)) + self.assertEqual(metadata[1], 'en') + + def test_line2doc(self): + # case with metadata=False (by default) + super(TestMalletCorpus, self).test_line2doc() + + # case with metadata=True + fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) + id2word = {1: 'mom', 2: 'window'} + + corpus = self.corpus_class(fname, id2word=id2word, metadata=True) + + # should return all words in doc + corpus.use_wordids = False + doc, (docid, doclang) = corpus.line2doc(self.CORPUS_LINE) + self.assertEqual(docid, '#3') + self.assertEqual(doclang, 'lang') + self.assertEqual( + sorted(doc), + [('mom', 1), ('was', 1), ('wash', 1), ('washed', 1), ('window', 2)]) + + # should return words in word2id + corpus.use_wordids = True + doc, (docid, doclang) = corpus.line2doc(self.CORPUS_LINE) + + self.assertEqual(docid, '#3') + self.assertEqual(doclang, 'lang') + self.assertEqual( + sorted(doc), + [(1, 1), (2, 2)]) + + +class TestTextCorpus(CorpusTestCase): + + def setUp(self): + self.corpus_class = textcorpus.TextCorpus + self.file_extension = '.txt' + + def test_load_with_metadata(self): + fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) + corpus = self.corpus_class(fname) + corpus.metadata = True + self.assertEqual(len(corpus), 9) + + docs = list(corpus) + self.assertEqual(len(docs), 9) + + for i, docmeta in enumerate(docs): + doc, metadata = docmeta + self.assertEqual(metadata[0], i) + + def test_default_preprocessing(self): + lines = [ + "Šéf chomutovských komunistů dostal poštou bílý prášek", + "this is a test for stopwords", + "zf tooth spaces " + ] + expected = [ + ['Sef', 'chomutovskych', 'komunistu', 'dostal', 'postou', 'bily', 'prasek'], + ['test', 'stopwords'], + ['tooth', 'spaces'] + ] + + corpus = self.corpus_from_lines(lines) + texts = list(corpus.get_texts()) + self.assertEqual(expected, texts) + + def corpus_from_lines(self, lines): + fpath = tempfile.mktemp() + with codecs.open(fpath, 'w', encoding='utf8') as f: + f.write('\n'.join(lines)) + + return self.corpus_class(fpath) + + def test_sample_text(self): + lines = ["document%d" % i for i in range(10)] + corpus = self.corpus_from_lines(lines) + corpus.tokenizer = lambda text: text.split() + docs = [doc for doc in corpus.get_texts()] + + sample1 = list(corpus.sample_texts(1)) + self.assertEqual(len(sample1), 1) + self.assertIn(sample1[0], docs) + + sample2 = list(corpus.sample_texts(len(lines))) + self.assertEqual(len(sample2), len(corpus)) + for i in range(len(corpus)): + self.assertEqual(sample2[i], ["document%s" % i]) + + with self.assertRaises(ValueError): + list(corpus.sample_texts(len(corpus) + 1)) + + with self.assertRaises(ValueError): + list(corpus.sample_texts(-1)) + + def test_sample_text_length(self): + lines = ["document%d" % i for i in range(10)] + corpus = self.corpus_from_lines(lines) + corpus.tokenizer = lambda text: text.split() + + sample1 = list(corpus.sample_texts(1, length=1)) + self.assertEqual(sample1[0], ["document0"]) + + sample2 = list(corpus.sample_texts(2, length=2)) + self.assertEqual(sample2[0], ["document0"]) + self.assertEqual(sample2[1], ["document1"]) + + def test_sample_text_seed(self): + lines = ["document%d" % i for i in range(10)] + corpus = self.corpus_from_lines(lines) + + sample1 = list(corpus.sample_texts(5, seed=42)) + sample2 = list(corpus.sample_texts(5, seed=42)) + self.assertEqual(sample1, sample2) + + def test_save(self): + pass + + def test_serialize(self): + pass + + def test_serialize_compressed(self): + pass + + def test_indexing(self): + pass + + +# Needed for the test_simple_tokenizer and test_list_tokenizers are the TestWikiCorpus class. +# Cannot be nested due to serializing. +def simple_tokenize(content, token_min_len=2, token_max_len=15, lower=True): + return [ + token for token in (content.lower() if lower else content).split() + if token_min_len <= len(token) <= token_max_len] + + +# Needed for the test_custom_tokenizer and test_list_tokenizers are the TestWikiCorpus class. +# Cannot be nested due to serializing. +def custom_tokenizer(content, token_min_len=2, token_max_len=15, lower=True): + return [ + to_unicode(token.lower()) if lower else to_unicode(token) for token in content.split() + if token_min_len <= len(token) <= token_max_len and not token.startswith('_') + ] + + +class TestWikiCorpus(TestTextCorpus): + def setUp(self): + self.corpus_class = wikicorpus.WikiCorpus + self.file_extension = '.xml.bz2' + self.fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) + self.enwiki = datapath('enwiki-latest-pages-articles1.xml-p000000010p000030302-shortened.bz2') + + def test_default_preprocessing(self): + expected = ['computer', 'human', 'interface'] + corpus = self.corpus_class(self.fname, article_min_tokens=0) + first_text = next(corpus.get_texts()) + self.assertEqual(expected, first_text) + + def test_len(self): + # When there is no min_token limit all 9 articles must be registered. + corpus = self.corpus_class(self.fname, article_min_tokens=0) + all_articles = corpus.get_texts() + assert (len(list(all_articles)) == 9) + + # With a huge min_token limit, all articles should be filtered out. + corpus = self.corpus_class(self.fname, article_min_tokens=100000) + all_articles = corpus.get_texts() + assert (len(list(all_articles)) == 0) + + def test_load_with_metadata(self): + corpus = self.corpus_class(self.fname, article_min_tokens=0) + corpus.metadata = True + self.assertEqual(len(corpus), 9) + + docs = list(corpus) + self.assertEqual(len(docs), 9) + + for i, docmeta in enumerate(docs): + doc, metadata = docmeta + article_no = i + 1 # Counting IDs from 1 + self.assertEqual(metadata[0], str(article_no)) + self.assertEqual(metadata[1], 'Article%d' % article_no) + + def test_load(self): + corpus = self.corpus_class(self.fname, article_min_tokens=0) + + docs = list(corpus) + # the deerwester corpus always has nine documents + self.assertEqual(len(docs), 9) + + def test_first_element(self): + """ + First two articles in this sample are + 1) anarchism + 2) autism + """ + corpus = self.corpus_class(self.enwiki, processes=1) + + texts = corpus.get_texts() + self.assertTrue(u'anarchism' in next(texts)) + self.assertTrue(u'autism' in next(texts)) + + def test_unicode_element(self): + """ + First unicode article in this sample is + 1) папа + """ + bgwiki = datapath('bgwiki-latest-pages-articles-shortened.xml.bz2') + corpus = self.corpus_class(bgwiki) + texts = corpus.get_texts() + self.assertTrue(u'папа' in next(texts)) + + def test_simple_tokenizer(self): + """ + define a simple tokenizer function and use it + """ + wc = self.corpus_class(self.enwiki, processes=1, tokenizer_func=simple_tokenizer, + token_max_len=16, token_min_len=1, lower=False) + row = wc.get_texts() + list_tokens = next(row) + self.assertTrue(u'Anarchism' in list_tokens) + self.assertTrue(u'collectivization' in list_tokens) + self.assertTrue(u'a' in list_tokens) + self.assertTrue(u'i.e.' in list_tokens) + + def test_custom_tokenizer(self): + """ + define a custom tokenizer function and use it + """ + wc = self.corpus_class(self.enwiki, processes=1, tokenizer_func=custom_tokenizer, + token_max_len=16, token_min_len=1, lower=False) + row = wc.get_texts() + list_tokens = next(row) + self.assertTrue(u'Anarchism' in list_tokens) + self.assertTrue(u'collectivization' in list_tokens) + self.assertTrue(u'a' in list_tokens) + self.assertTrue(u'i.e.' in list_tokens) + + def test_list_tokenizers(self): + """ + define a list containing two tokenizers functions (simple and custom) and use it + """ + wc = self.corpus_class(self.enwiki, processes=1, tokenizer_func=[simple_tokenizer, custom_tokenizer], + token_max_len=16, token_min_len=1, lower=False) + row = wc.get_texts() + list_tokens = next(row) + self.assertTrue(u'Anarchism' in list_tokens) + self.assertTrue(u'collectivization' in list_tokens) + self.assertTrue(u'a' in list_tokens) + self.assertTrue(u'i.e.' in list_tokens) + + def test_lower_case_set_true(self): + """ + Set the parameter lower to True and check that upper case 'Anarchism' token doesnt exist + """ + corpus = self.corpus_class(self.enwiki, processes=1, lower=True) + row = corpus.get_texts() + list_tokens = next(row) + self.assertTrue(u'Anarchism' not in list_tokens) + self.assertTrue(u'anarchism' in list_tokens) + + def test_lower_case_set_false(self): + """ + Set the parameter lower to False and check that upper case Anarchism' token exists + """ + corpus = self.corpus_class(self.enwiki, processes=1, lower=False) + row = corpus.get_texts() + list_tokens = next(row) + self.assertTrue(u'Anarchism' in list_tokens) + self.assertTrue(u'anarchism' in list_tokens) + + def test_min_token_len_not_set(self): + """ + Don't set the parameter token_min_len and check that 'a' as a token doesn't exist + Default token_min_len=2 + """ + corpus = self.corpus_class(self.enwiki, processes=1) + self.assertTrue(u'a' not in next(corpus.get_texts())) + + def test_min_token_len_set(self): + """ + Set the parameter token_min_len to 1 and check that 'a' as a token exists + """ + corpus = self.corpus_class(self.enwiki, processes=1, token_min_len=1) + self.assertTrue(u'a' in next(corpus.get_texts())) + + def test_max_token_len_not_set(self): + """ + Don't set the parameter token_max_len and check that 'collectivisation' as a token doesn't exist + Default token_max_len=15 + """ + corpus = self.corpus_class(self.enwiki, processes=1) + self.assertTrue(u'collectivization' not in next(corpus.get_texts())) + + def test_max_token_len_set(self): + """ + Set the parameter token_max_len to 16 and check that 'collectivisation' as a token exists + """ + corpus = self.corpus_class(self.enwiki, processes=1, token_max_len=16) + self.assertTrue(u'collectivization' in next(corpus.get_texts())) + + def test_removed_table_markup(self): + """ + Check if all the table markup has been removed. + """ + enwiki_file = datapath('enwiki-table-markup.xml.bz2') + corpus = self.corpus_class(enwiki_file) + texts = corpus.get_texts() + table_markup = ["style", "class", "border", "cellspacing", "cellpadding", "colspan", "rowspan"] + for text in texts: + for word in table_markup: + self.assertTrue(word not in text) + + def test_get_stream(self): + wiki = self.corpus_class(self.enwiki) + sample_text_wiki = next(wiki.getstream()).decode()[1:14] + self.assertEqual(sample_text_wiki, "mediawiki xml") + + # #TODO: sporadic failure to be investigated + # def test_get_texts_returns_generator_of_lists(self): + # corpus = self.corpus_class(self.enwiki) + # l = corpus.get_texts() + # self.assertEqual(type(l), types.GeneratorType) + # first = next(l) + # self.assertEqual(type(first), list) + # self.assertTrue(isinstance(first[0], bytes) or isinstance(first[0], str)) + + def test_sample_text(self): + # Cannot instantiate WikiCorpus from lines + pass + + def test_sample_text_length(self): + # Cannot instantiate WikiCorpus from lines + pass + + def test_sample_text_seed(self): + # Cannot instantiate WikiCorpus from lines + pass + + def test_empty_input(self): + # An empty file is not legit XML + pass + + def test_custom_filterfunction(self): + def reject_all(elem, *args, **kwargs): + return False + corpus = self.corpus_class(self.enwiki, filter_articles=reject_all) + texts = corpus.get_texts() + self.assertFalse(any(texts)) + + def keep_some(elem, title, *args, **kwargs): + return title[0] == 'C' + corpus = self.corpus_class(self.enwiki, filter_articles=reject_all) + corpus.metadata = True + texts = corpus.get_texts() + for text, (pageid, title) in texts: + self.assertEquals(title[0], 'C') + + +class TestTextDirectoryCorpus(unittest.TestCase): + + def write_one_level(self, *args): + if not args: + args = ('doc1', 'doc2') + dirpath = tempfile.mkdtemp() + self.write_docs_to_directory(dirpath, *args) + return dirpath + + def write_docs_to_directory(self, dirpath, *args): + for doc_num, name in enumerate(args): + with open(os.path.join(dirpath, name), 'w') as f: + f.write('document %d content' % doc_num) + + def test_one_level_directory(self): + dirpath = self.write_one_level() + + corpus = textcorpus.TextDirectoryCorpus(dirpath) + self.assertEqual(len(corpus), 2) + docs = list(corpus) + self.assertEqual(len(docs), 2) + + def write_two_levels(self): + dirpath = self.write_one_level() + next_level = os.path.join(dirpath, 'level_two') + os.mkdir(next_level) + self.write_docs_to_directory(next_level, 'doc1', 'doc2') + return dirpath, next_level + + def test_two_level_directory(self): + dirpath, next_level = self.write_two_levels() + + corpus = textcorpus.TextDirectoryCorpus(dirpath) + self.assertEqual(len(corpus), 4) + docs = list(corpus) + self.assertEqual(len(docs), 4) + + corpus = textcorpus.TextDirectoryCorpus(dirpath, min_depth=1) + self.assertEqual(len(corpus), 2) + docs = list(corpus) + self.assertEqual(len(docs), 2) + + corpus = textcorpus.TextDirectoryCorpus(dirpath, max_depth=0) + self.assertEqual(len(corpus), 2) + docs = list(corpus) + self.assertEqual(len(docs), 2) + + def test_filename_filtering(self): + dirpath = self.write_one_level('test1.log', 'test1.txt', 'test2.log', 'other1.log') + corpus = textcorpus.TextDirectoryCorpus(dirpath, pattern=r"test.*\.log") + filenames = list(corpus.iter_filepaths()) + expected = [os.path.join(dirpath, name) for name in ('test1.log', 'test2.log')] + self.assertEqual(sorted(expected), sorted(filenames)) + + corpus.pattern = ".*.txt" + filenames = list(corpus.iter_filepaths()) + expected = [os.path.join(dirpath, 'test1.txt')] + self.assertEqual(expected, filenames) + + corpus.pattern = None + corpus.exclude_pattern = ".*.log" + filenames = list(corpus.iter_filepaths()) + self.assertEqual(expected, filenames) + + def test_lines_are_documents(self): + dirpath = tempfile.mkdtemp() + lines = ['doc%d text' % i for i in range(5)] + fpath = os.path.join(dirpath, 'test_file.txt') + with open(fpath, 'w') as f: + f.write('\n'.join(lines)) + + corpus = textcorpus.TextDirectoryCorpus(dirpath, lines_are_documents=True) + docs = [doc for doc in corpus.getstream()] + self.assertEqual(len(lines), corpus.length) # should have cached + self.assertEqual(lines, docs) + + corpus.lines_are_documents = False + docs = [doc for doc in corpus.getstream()] + self.assertEqual(1, corpus.length) + self.assertEqual('\n'.join(lines), docs[0]) + + def test_non_trivial_structure(self): + """Test with non-trivial directory structure, shown below: + . + ├── 0.txt + ├── a_folder + │ └── 1.txt + └── b_folder + ├── 2.txt + ├── 3.txt + └── c_folder + └── 4.txt + """ + dirpath = tempfile.mkdtemp() + self.write_docs_to_directory(dirpath, '0.txt') + + a_folder = os.path.join(dirpath, 'a_folder') + os.mkdir(a_folder) + self.write_docs_to_directory(a_folder, '1.txt') + + b_folder = os.path.join(dirpath, 'b_folder') + os.mkdir(b_folder) + self.write_docs_to_directory(b_folder, '2.txt', '3.txt') + + c_folder = os.path.join(b_folder, 'c_folder') + os.mkdir(c_folder) + self.write_docs_to_directory(c_folder, '4.txt') + + corpus = textcorpus.TextDirectoryCorpus(dirpath) + filenames = list(corpus.iter_filepaths()) + base_names = sorted(name[len(dirpath) + 1:] for name in filenames) + expected = sorted([ + '0.txt', + 'a_folder/1.txt', + 'b_folder/2.txt', + 'b_folder/3.txt', + 'b_folder/c_folder/4.txt' + ]) + expected = [os.path.normpath(path) for path in expected] + self.assertEqual(expected, base_names) + + corpus.max_depth = 1 + self.assertEqual(expected[:-1], base_names[:-1]) + + corpus.min_depth = 1 + self.assertEqual(expected[2:-1], base_names[2:-1]) + + corpus.max_depth = 0 + self.assertEqual(expected[2:], base_names[2:]) + + corpus.pattern = "4.*" + self.assertEqual(expected[-1], base_names[-1]) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main()