diff --git a/nlg/search.py b/nlg/search.py
index 6574b2d..716f6f1 100644
--- a/nlg/search.py
+++ b/nlg/search.py
@@ -46,9 +46,12 @@ def _sort_search_results(items, priorities=SEARCH_PRIORITIES):
     Prioritized search results - for each {token: search_matches} pair,
     sort search_matches such that a higher priority search result is enabled.
     """
-    match_ix = [[p.items() <= item.items() for p in priorities] for item in items]
-    min_match = [m.index(True) for m in match_ix]
-    items[min_match.index(min(min_match))]['enabled'] = True
+    if len(items) > 1:
+        match_ix = [[p.items() <= item.items() for p in priorities] for item in items]
+        min_match = [m.index(True) for m in match_ix]
+        items[min_match.index(min(min_match))]['enabled'] = True
+    else:
+        items[0]['enabled'] = True
     return items
 
 
@@ -230,6 +233,7 @@ def search(self, text, colname_fmt='df.columns[{}]',
             the source dataframe, and values are a list of locations in the df
             where they are found.
         """
+        self.search_nes(text)
         if len(text.text) <= _df_maxlen(self.df):
             for i in _text_search_array(text.text, self.df.columns):
                 self.results[text] = {'location': 'colname', 'tmpl': colname_fmt.format(i),
@@ -242,7 +246,6 @@ def search(self, text, colname_fmt='df.columns[{}]',
                                       'type': 'doc'}
         else:
-            self.search_nes(text)
             for token, ix in self.search_columns(text, **kwargs).items():
                 ix = utils.sanitize_indices(self.df.shape, ix, 1)
                 self.results[token] = {'location': 'colname', 'tmpl': colname_fmt.format(ix),
diff --git a/nlg/tests/test_search.py b/nlg/tests/test_search.py
index 158a94c..3a5aaa7 100644
--- a/nlg/tests/test_search.py
+++ b/nlg/tests/test_search.py
@@ -107,6 +107,8 @@ class TestSearch(unittest.TestCase):
     def setUpClass(cls):
         fpath = op.join(op.dirname(__file__), "data", "actors.csv")
         cls.df = pd.read_csv(fpath, encoding='utf-8')
+        fpath = op.join(op.dirname(__file__), "data", "imdb_ratings.csv")
+        cls.imdb = pd.read_csv(fpath, encoding='utf-8')
 
     def test_dfsearches(self):
         x = search.DFSearchResults()
@@ -118,7 +120,6 @@ def test_dfsearches(self):
         x['hello'] = 'underworld'
         self.assertDictEqual(x, {'hello': ['world', 'underworld']})
 
-    # @unittest.skip("Temporary")
     def test_search_args(self):
         args = utils.sanitize_fh_args({"_sort": ["-votes"]}, self.df)
         doc = nlp("James Stewart is the top voted actor.")
@@ -145,9 +146,7 @@ def test_search_args_literal(self):
                 "type": "token"}})
 
     def test_templatize(self):
-        fpath = op.join(op.dirname(__file__), "data", "actors.csv")
-        df = pd.read_csv(fpath, encoding='utf-8')
-        df.sort_values("votes", ascending=False, inplace=True)
+        df = self.df.sort_values("votes", ascending=False)
         df.reset_index(inplace=True, drop=True)
 
         doc = nlp("""
@@ -190,7 +189,6 @@
         )
 
     def test_search_sort(self):
-
         results = [
             {'tmpl': 'df.loc[0, "name"]', 'type': 'ne', 'location': 'cell'},
             {'tmpl': 'df.columns[0]', 'type': 'token', 'location': 'colname'},
@@ -236,17 +234,23 @@ def test_single_entity_search(self):
         self.assertEqual(variable.template, '{{ df["name"].iloc[0] }}')
 
     def test_literal_search(self):
-        df = pd.read_csv(
-            op.join(op.dirname(__file__), 'data', 'imdb_ratings.csv'), encoding='utf8')
         texts = ['How I Met Your Mother', 'Sherlock', 'Dexter', 'Breaking Bad']
         for t in texts:
             doc = nlp(t)
-            nugget = search.templatize(doc, {}, df)
+            nugget = search.templatize(doc, {}, self.imdb)
             self.assertEqual(len(nugget.tokenmap), 1)
             for token, variable in nugget.tokenmap.items():
                 self.assertEqual(token.text, t)
             self.assertRegex(nugget.template, r'{{ df\["name"\].iloc\[-*\d+\] }}')
 
+    def test_search_short_strings(self):
+        # Check strings that are shorter than the max length of the df,
+        # but still not a literal match
+        nugget = search.templatize(nlp('Dexter is a good show'), {}, self.imdb)
+        self.assertEqual(len(nugget.tokenmap), 1)
+        token, variable = nugget.tokenmap.popitem()
+        self.assertRegex(variable.enabled_source['tmpl'], r'df\["name"\].iloc\[-*\d+\]')
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/nlg/utils.py b/nlg/utils.py
index 8b5780d..50cd0cc 100644
--- a/nlg/utils.py
+++ b/nlg/utils.py
@@ -8,7 +8,7 @@
 import re
 
 import pandas as pd
-from spacy.tokens import Token
+from spacy.tokens import Token, Doc, Span
 from tornado.template import Template
 
 from gramex.data import filter as gfilter  # NOQA: F401
@@ -111,12 +111,19 @@ def __call__(self, func):
 
 def is_overlap(x, y):
     """Whether the token x is contained within any span in the sequence y."""
+    if len(y) == 0:
+        return False
     if isinstance(x, Token):
         if x.pos_ == "NUM":
             return False
     elif 'NUM' in [c.pos_ for c in x]:
         return False
-    return any([x.text in yy.text for yy in y])
+    if len(y) > 1:
+        return any([x.text in yy.text for yy in y])
+    y = y.pop()
+    if isinstance(x, (Token, Span)) and isinstance(y, Doc):
+        return x.doc == y
+    return False
 
 
 def unoverlap(tokens):
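
Notes on the change, with illustrative sketches (not part of the patch itself).

1. The new `len(items) > 1` guard in `_sort_search_results`: with a single
candidate there is nothing to rank, so it is enabled unconditionally. Under
the old code, a lone match that fit no priority pattern would raise
ValueError from `m.index(True)`. A minimal, self-contained sketch of the
behaviour; `SEARCH_PRIORITIES` here is an assumed, abbreviated stand-in for
the list defined in nlg/search.py:

    # Assumed, simplified priority list; the key vocabulary mirrors the
    # {'tmpl', 'type', 'location'} dicts seen in test_search_sort.
    SEARCH_PRIORITIES = [
        {'type': 'ne'},                          # named entities win
        {'location': 'colname'},                 # then column names
        {'location': 'cell', 'type': 'token'},   # then cell values
    ]

    def _sort_search_results(items, priorities=SEARCH_PRIORITIES):
        # Same logic as the patched function in nlg/search.py.
        if len(items) > 1:
            match_ix = [[p.items() <= item.items() for p in priorities] for item in items]
            min_match = [m.index(True) for m in match_ix]
            items[min_match.index(min(min_match))]['enabled'] = True
        else:
            items[0]['enabled'] = True
        return items

    results = [
        {'tmpl': 'df.columns[0]', 'location': 'colname', 'type': 'token'},
        {'tmpl': 'df.loc[0, "name"]', 'location': 'cell', 'type': 'ne'},
    ]
    _sort_search_results(results)
    assert results[1].get('enabled')  # the named-entity match wins

    # Previously a lone match fitting no priority raised ValueError from
    # m.index(True); now the only candidate is simply enabled.
    only = [{'tmpl': 'df.shape[0]', 'location': 'nrows', 'type': 'unknown'}]
    assert _sort_search_results(only)[0]['enabled']

2. The new single-element branch of `is_overlap`: when `y` holds exactly one
object and that object is a whole spaCy `Doc`, overlap now means "x belongs
to that doc". A sketch assuming the patched nlg package is importable and
spaCy is installed; `spacy.blank('en')` stands in for whatever model the
project actually loads:

    import spacy
    from nlg.utils import is_overlap

    nlp = spacy.blank('en')  # a bare pipeline suffices; pos_ stays empty
    doc = nlp('Dexter is a good show')

    # Single-element set holding the whole Doc: the new branch compares
    # the token's parent doc against that Doc.
    assert is_overlap(doc[0], {doc})
    # A token from a different doc does not overlap.
    assert not is_overlap(nlp('Sherlock')[0], {doc})
    # The new empty-container guard short-circuits. Note that y.pop()
    # mutates its argument, so callers should pass a fresh container.
    assert not is_overlap(doc[0], set())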