diff --git a/spellchecker/info.py b/spellchecker/info.py index 26de3b9..4ff8196 100644 --- a/spellchecker/info.py +++ b/spellchecker/info.py @@ -5,7 +5,7 @@ __maintainer__ = "Tyler Barrus" __email__ = "barrust@gmail.com" __license__ = "MIT" -__version__ = "0.5.1" +__version__ = "0.5.2" __credits__ = ["Peter Norvig"] __url__ = "https://github.com/barrust/pyspellchecker" __bugtrack_url__ = "{0}/issues".format(__url__) diff --git a/spellchecker/spellchecker.py b/spellchecker/spellchecker.py index a32cf10..589c77b 100644 --- a/spellchecker/spellchecker.py +++ b/spellchecker/spellchecker.py @@ -7,7 +7,7 @@ import string from collections import Counter -from .utils import load_file, write_file, _parse_into_words +from .utils import load_file, write_file, _parse_into_words, ENSURE_UNICODE class SpellChecker(object): @@ -62,10 +62,12 @@ def __init__( def __contains__(self, key): """ setup easier known checks """ + key = ENSURE_UNICODE(key) return key in self._word_frequency def __getitem__(self, key): """ setup easier frequency checks """ + key = ENSURE_UNICODE(key) return self._word_frequency[key] @property @@ -105,6 +107,7 @@ def split_words(self, text): text (str): The text to split into individual words Returns: list(str): A listing of all words in the provided text """ + text = ENSURE_UNICODE(text) return self._tokenizer(text) def export(self, filepath, encoding="utf-8", gzipped=True): @@ -131,6 +134,7 @@ def word_probability(self, word, total_words=None): float: The probability that the word is the correct word """ if total_words is None: total_words = self._word_frequency.total_words + word = ENSURE_UNICODE(word) return self._word_frequency.dictionary[word] / total_words def correction(self, word): @@ -140,6 +144,7 @@ def correction(self, word): word (str): The word to correct Returns: str: The most likely candidate """ + word = ENSURE_UNICODE(word) candidates = list(self.candidates(word)) return max(sorted(candidates), key=self.word_probability) @@ -151,6 +156,7 @@ def candidates(self, word): word (str): The word for which to calculate candidate spellings Returns: set: The set of words that are possible candidates """ + word = ENSURE_UNICODE(word) if self.known([word]): # short-cut if word is correct already return {word} # get edit distance 1... @@ -174,6 +180,7 @@ def known(self, words): Returns: set: The set of those words from the input that are in the \ corpus """ + words = [ENSURE_UNICODE(w) for w in words] tmp = [w if self._case_sensitive else w.lower() for w in words] return set( w @@ -191,6 +198,7 @@ def unknown(self, words): Returns: set: The set of those words from the input that are not in \ the corpus """ + words = [ENSURE_UNICODE(w) for w in words] tmp = [ w if self._case_sensitive else w.lower() for w in words @@ -207,7 +215,7 @@ def edit_distance_1(self, word): Returns: set: The set of strings that are edit distance one from the \ provided word """ - word = word.lower() + word = ENSURE_UNICODE(word).lower() if self._check_if_should_check(word) is False: return {word} letters = self._word_frequency.letters @@ -227,7 +235,7 @@ def edit_distance_2(self, word): Returns: set: The set of strings that are edit distance two from the \ provided word """ - word = word.lower() + word = ENSURE_UNICODE(word).lower() return [ e2 for e1 in self.edit_distance_1(word) for e2 in self.edit_distance_1(e1) ] @@ -241,8 +249,13 @@ def __edit_distance_alt(self, words): Returns: set: The set of strings that are edit distance two from the \ provided words """ - words = [word.lower() for word in words] - return [e2 for e1 in words for e2 in self.edit_distance_1(e1)] + words = [ENSURE_UNICODE(w) for w in words] + tmp = [ + w if self._case_sensitive else w.lower() + for w in words + if self._check_if_should_check(w) + ] + return [e2 for e1 in tmp for e2 in self.edit_distance_1(e1)] @staticmethod def _check_if_should_check(word): @@ -283,11 +296,13 @@ def __init__(self, tokenizer=None, case_sensitive=False): def __contains__(self, key): """ turn on contains """ + key = ENSURE_UNICODE(key) key = key if self._case_sensitive else key.lower() return key in self._dictionary def __getitem__(self, key): """ turn on getitem """ + key = ENSURE_UNICODE(key) key = key if self._case_sensitive else key.lower() return self._dictionary[key] @@ -298,6 +313,7 @@ def pop(self, key, default=None): Args: key (str): The key to remove default (obj): The value to return if key is not present """ + key = ENSURE_UNICODE(key) key = key if self._case_sensitive else key.lower() return self._dictionary.pop(key, default) @@ -344,6 +360,7 @@ def tokenize(self, text): str: The next `word` in the tokenized string Note: This is the same as the `spellchecker.split_words()` """ + text = ENSURE_UNICODE(text) for word in self._tokenizer(text): yield word if self._case_sensitive else word.lower() @@ -408,6 +425,7 @@ def load_text(self, text, tokenizer=None): text (str): The text to be loaded tokenizer (function): The function to use to tokenize a string """ + text = ENSURE_UNICODE(text) if tokenizer: words = [x if self._case_sensitive else x.lower() for x in tokenizer(text)] else: @@ -421,6 +439,7 @@ def load_words(self, words): Args: words (list): The list of words to be loaded """ + words = [ENSURE_UNICODE(w) for w in words] self._dictionary.update( [word if self._case_sensitive else word.lower() for word in words] ) @@ -431,6 +450,7 @@ def add(self, word): Args: word (str): The word to add """ + word = ENSURE_UNICODE(word) self.load_words([word]) def remove_words(self, words): @@ -438,6 +458,7 @@ def remove_words(self, words): Args: words (list): The list of words to remove """ + words = [ENSURE_UNICODE(w) for w in words] for word in words: self._dictionary.pop(word if self._case_sensitive else word.lower()) self._update_dictionary() @@ -447,6 +468,7 @@ def remove(self, word): Args: word (str): The word to remove """ + word = ENSURE_UNICODE(word) self._dictionary.pop(word if self._case_sensitive else word.lower()) self._update_dictionary() diff --git a/spellchecker/utils.py b/spellchecker/utils.py index 4abc1f3..aefcc10 100644 --- a/spellchecker/utils.py +++ b/spellchecker/utils.py @@ -9,11 +9,22 @@ READMODE = 'rb' WRITEMODE = 'wb' OPEN = io.open # hijack this + + def ENSURE_UNICODE(s, encoding='utf-8'): + if isinstance(s, str): + return s.decode(encoding) + return s + else: READMODE = 'rt' WRITEMODE = 'wt' OPEN = open + def ENSURE_UNICODE(s, encoding='utf-8'): + if isinstance(s, bytes): + return s.decode(encoding) + return s + @contextlib.contextmanager def __gzip_read(filename, mode='rb', encoding='UTF-8'): diff --git a/tests/spellchecker_test.py b/tests/spellchecker_test.py index 97f1512..b200d6b 100644 --- a/tests/spellchecker_test.py +++ b/tests/spellchecker_test.py @@ -36,13 +36,13 @@ def test_candidates(self): self.assertEqual(spell.candidates('manasaeds'), {'manasaeds'}) def test_words(self): - ''' rest the parsing of words ''' + ''' test the parsing of words ''' spell = SpellChecker() res = ['this', 'is', 'a', 'test', 'of', 'this'] self.assertEqual(spell.split_words('This is a test of this'), res) def test_words_more_complete(self): - ''' rest the parsing of words ''' + ''' test the parsing of words ''' spell = SpellChecker() res = ['this', 'is', 'a', 'test', 'of', 'the', 'word', 'parser', 'it', 'should', 'work', 'correctly'] self.assertEqual(spell.split_words('This is a test of the word parser. It should work correctly!!!'), res) @@ -368,3 +368,15 @@ def tokens(txt): self.assertFalse('awesome' in spell) self.assertTrue(spell['whale']) self.assertTrue('sea.' in spell) + + def test_bytes_input(self): + """ Test using bytes instead of unicode as input """ + + var = b"bike" + + here = os.path.dirname(__file__) + filepath = '{}/resources/small_dictionary.json'.format(here) + spell = SpellChecker(language=None, local_dictionary=filepath) + + self.assertTrue(var in spell) + self.assertEqual(spell[var], 60)