Skip to content

Commit

Permalink
FIX: Handle short strings better
Browse files Browse the repository at this point in the history
Input strings that were shorter than the longest string in the dataset
were being force-searched as literals. This condition has been removed. See #40.
  • Loading branch information
jaidevd committed Feb 11, 2020
1 parent 08c54ef commit 0a21887
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 14 deletions.
11 changes: 7 additions & 4 deletions nlg/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,12 @@ def _sort_search_results(items, priorities=SEARCH_PRIORITIES):
Prioritized search results - for each {token: search_matches} pair, sort
search_matches such that a higher priority search result is enabled.
"""
match_ix = [[p.items() <= item.items() for p in priorities] for item in items]
min_match = [m.index(True) for m in match_ix]
items[min_match.index(min(min_match))]['enabled'] = True
if len(items) > 1:
match_ix = [[p.items() <= item.items() for p in priorities] for item in items]
min_match = [m.index(True) for m in match_ix]
items[min_match.index(min(min_match))]['enabled'] = True
else:
items[0]['enabled'] = True
return items


Expand Down Expand Up @@ -230,6 +233,7 @@ def search(self, text, colname_fmt='df.columns[{}]',
the source dataframe, and values are a list of locations in the df
where they are found.
"""
self.search_nes(text)
if len(text.text) <= _df_maxlen(self.df):
for i in _text_search_array(text.text, self.df.columns):
self.results[text] = {'location': 'colname', 'tmpl': colname_fmt.format(i),
Expand All @@ -242,7 +246,6 @@ def search(self, text, colname_fmt='df.columns[{}]',
'type': 'doc'}

else:
self.search_nes(text)
for token, ix in self.search_columns(text, **kwargs).items():
ix = utils.sanitize_indices(self.df.shape, ix, 1)
self.results[token] = {'location': 'colname', 'tmpl': colname_fmt.format(ix),
Expand Down
20 changes: 12 additions & 8 deletions nlg/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ class TestSearch(unittest.TestCase):
def setUpClass(cls):
fpath = op.join(op.dirname(__file__), "data", "actors.csv")
cls.df = pd.read_csv(fpath, encoding='utf-8')
fpath = op.join(op.dirname(__file__), "data", "imdb_ratings.csv")
cls.imdb = pd.read_csv(fpath, encoding='utf-8')

def test_dfsearches(self):
x = search.DFSearchResults()
Expand All @@ -118,7 +120,6 @@ def test_dfsearches(self):
x['hello'] = 'underworld'
self.assertDictEqual(x, {'hello': ['world', 'underworld']})

# @unittest.skip("Temporary")
def test_search_args(self):
args = utils.sanitize_fh_args({"_sort": ["-votes"]}, self.df)
doc = nlp("James Stewart is the top voted actor.")
Expand All @@ -145,9 +146,7 @@ def test_search_args_literal(self):
"type": "token"}})

def test_templatize(self):
fpath = op.join(op.dirname(__file__), "data", "actors.csv")
df = pd.read_csv(fpath, encoding='utf-8')
df.sort_values("votes", ascending=False, inplace=True)
df = self.df.sort_values("votes", ascending=False)
df.reset_index(inplace=True, drop=True)

doc = nlp("""
Expand Down Expand Up @@ -190,7 +189,6 @@ def test_templatize(self):
)

def test_search_sort(self):

results = [
{'tmpl': 'df.loc[0, "name"]', 'type': 'ne', 'location': 'cell'},
{'tmpl': 'df.columns[0]', 'type': 'token', 'location': 'colname'},
Expand Down Expand Up @@ -236,17 +234,23 @@ def test_single_entity_search(self):
self.assertEqual(variable.template, '{{ df["name"].iloc[0] }}')

def test_literal_search(self):
df = pd.read_csv(
op.join(op.dirname(__file__), 'data', 'imdb_ratings.csv'), encoding='utf8')
texts = ['How I Met Your Mother', 'Sherlock', 'Dexter', 'Breaking Bad']
for t in texts:
doc = nlp(t)
nugget = search.templatize(doc, {}, df)
nugget = search.templatize(doc, {}, self.imdb)
self.assertEqual(len(nugget.tokenmap), 1)
for token, variable in nugget.tokenmap.items():
self.assertEqual(token.text, t)
self.assertRegex(nugget.template, r'{{ df\["name"\].iloc\[-*\d+\] }}')

def test_search_short_strings(self):
# Check strings that are shorter than the max length of the df,
# but still not a literal match
nugget = search.templatize(nlp('Dexter is a good show'), {}, self.imdb)
self.assertEqual(len(nugget.tokenmap), 1)
token, variable = nugget.tokenmap.popitem()
self.assertRegex(variable.enabled_source['tmpl'], r'df\["name"\].iloc\[-*\d+\]')


if __name__ == "__main__":
unittest.main()
11 changes: 9 additions & 2 deletions nlg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import re

import pandas as pd
from spacy.tokens import Token
from spacy.tokens import Token, Doc, Span
from tornado.template import Template

from gramex.data import filter as gfilter # NOQA: F401
Expand Down Expand Up @@ -111,12 +111,19 @@ def __call__(self, func):

def is_overlap(x, y):
"""Whether the token x is contained within any span in the sequence y."""
if len(y) == 0:
return False
if isinstance(x, Token):
if x.pos_ == "NUM":
return False
elif 'NUM' in [c.pos_ for c in x]:
return False
return any([x.text in yy.text for yy in y])
if len(y) > 1:
return any([x.text in yy.text for yy in y])
y = y.pop()
if isinstance(x, (Token, Span)) and isinstance(y, Doc):
return x.doc == y
return False


def unoverlap(tokens):
Expand Down

0 comments on commit 0a21887

Please sign in to comment.