diff --git a/polyfuzz/models/_tfidf.py b/polyfuzz/models/_tfidf.py
index e065f11..243b1ba 100644
--- a/polyfuzz/models/_tfidf.py
+++ b/polyfuzz/models/_tfidf.py
@@ -51,7 +51,8 @@ def __init__(self,
                  min_similarity: float = 0.75,
                  top_n: int = 1,
                  cosine_method: str = "sparse",
-                 model_id: str = None):
+                 model_id: str = None,
+                 use_word_grams: bool = False):
         super().__init__(model_id)
         self.type = "TF-IDF"
         self.n_gram_range = n_gram_range
@@ -61,6 +62,7 @@ def __init__(self,
         self.top_n = top_n
         self.vectorizer = None
         self.tf_idf_to = None
+        self.use_word_grams = use_word_grams
 
     def match(self,
               from_list: List[str],
@@ -125,9 +127,20 @@ def _create_ngrams(self, string: str) -> List[str]:
         string = _clean_string(string)
 
         result = []
+        if self.use_word_grams:
+            # Word-level n-grams: slide the window over space-separated
+            # tokens and keep the space when joining, so "ab c" and "a bc"
+            # do not both collapse into the same "abc" feature.
+            units = [token for token in string.split(" ") if token]
+            separator = " "
+        else:
+            # Character-level n-grams (the pre-existing behaviour).
+            units = string
+            separator = ""
         for n in range(self.n_gram_range[0], self.n_gram_range[1]+1):
-            ngrams = zip(*[string[i:] for i in range(n)])
-            ngrams = [''.join(ngram) for ngram in ngrams if ' ' not in ngram]
+            ngrams = zip(*[units[i:] for i in range(n)])
+            # The ' ' filter only fires for character windows; a token is never ' '.
+            ngrams = [separator.join(ngram) for ngram in ngrams if ' ' not in ngram]
             result.extend(ngrams)
 
         return result