FIX: Handle short strings better #44

Merged 1 commit on Feb 11, 2020
11 changes: 7 additions & 4 deletions nlg/search.py
@@ -46,9 +46,12 @@ def _sort_search_results(items, priorities=SEARCH_PRIORITIES):
     Prioritized search results - for each {token: search_matches} pair, sort
     search_matches such that a higher priority search result is enabled.
     """
-    match_ix = [[p.items() <= item.items() for p in priorities] for item in items]
-    min_match = [m.index(True) for m in match_ix]
-    items[min_match.index(min(min_match))]['enabled'] = True
+    if len(items) > 1:
+        match_ix = [[p.items() <= item.items() for p in priorities] for item in items]
+        min_match = [m.index(True) for m in match_ix]
+        items[min_match.index(min(min_match))]['enabled'] = True
+    else:
+        items[0]['enabled'] = True
     return items


@@ -230,6 +233,7 @@ def search(self, text, colname_fmt='df.columns[{}]',
         the source dataframe, and values are a list of locations in the df
         where they are found.
         """
+        self.search_nes(text)
         if len(text.text) <= _df_maxlen(self.df):
             for i in _text_search_array(text.text, self.df.columns):
                 self.results[text] = {'location': 'colname', 'tmpl': colname_fmt.format(i),
@@ -242,7 +246,6 @@
                                   'type': 'doc'}

         else:
-            self.search_nes(text)
             for token, ix in self.search_columns(text, **kwargs).items():
                 ix = utils.sanitize_indices(self.df.shape, ix, 1)
                 self.results[token] = {'location': 'colname', 'tmpl': colname_fmt.format(ix),
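
A sketch of the end-to-end effect this change enables, mirroring the new test added below — the spaCy model name and CSV path are assumptions, while templatize, tokenmap and enabled_source are used exactly as the tests use them:

```python
# Sketch only: the model name and CSV path are assumptions.
import pandas as pd
import spacy

from nlg import search

nlp = spacy.load('en_core_web_sm')
df = pd.read_csv('nlg/tests/data/imdb_ratings.csv', encoding='utf-8')

# 'Dexter' is shorter than the longest cell in df but is not an exact cell
# match. With search_nes() now running before the length check, it still
# resolves to a cell lookup rather than going unmatched.
nugget = search.templatize(nlp('Dexter is a good show'), {}, df)
token, variable = nugget.tokenmap.popitem()
print(variable.enabled_source['tmpl'])   # e.g. df["name"].iloc[3]
```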
20 changes: 12 additions & 8 deletions nlg/tests/test_search.py
@@ -107,6 +107,8 @@ class TestSearch(unittest.TestCase):
     def setUpClass(cls):
         fpath = op.join(op.dirname(__file__), "data", "actors.csv")
         cls.df = pd.read_csv(fpath, encoding='utf-8')
+        fpath = op.join(op.dirname(__file__), "data", "imdb_ratings.csv")
+        cls.imdb = pd.read_csv(fpath, encoding='utf-8')

     def test_dfsearches(self):
         x = search.DFSearchResults()
@@ -118,7 +120,6 @@ def test_dfsearches(self):
         x['hello'] = 'underworld'
         self.assertDictEqual(x, {'hello': ['world', 'underworld']})

-    # @unittest.skip("Temporary")
     def test_search_args(self):
         args = utils.sanitize_fh_args({"_sort": ["-votes"]}, self.df)
         doc = nlp("James Stewart is the top voted actor.")
@@ -145,9 +146,7 @@ def test_search_args_literal(self):
"type": "token"}})

def test_templatize(self):
fpath = op.join(op.dirname(__file__), "data", "actors.csv")
df = pd.read_csv(fpath, encoding='utf-8')
df.sort_values("votes", ascending=False, inplace=True)
df = self.df.sort_values("votes", ascending=False)
df.reset_index(inplace=True, drop=True)

doc = nlp("""
@@ -190,7 +189,6 @@ def test_templatize(self):
         )

     def test_search_sort(self):
-
         results = [
             {'tmpl': 'df.loc[0, "name"]', 'type': 'ne', 'location': 'cell'},
             {'tmpl': 'df.columns[0]', 'type': 'token', 'location': 'colname'},
@@ -236,17 +234,23 @@ def test_single_entity_search(self):
         self.assertEqual(variable.template, '{{ df["name"].iloc[0] }}')

     def test_literal_search(self):
-        df = pd.read_csv(
-            op.join(op.dirname(__file__), 'data', 'imdb_ratings.csv'), encoding='utf8')
         texts = ['How I Met Your Mother', 'Sherlock', 'Dexter', 'Breaking Bad']
         for t in texts:
             doc = nlp(t)
-            nugget = search.templatize(doc, {}, df)
+            nugget = search.templatize(doc, {}, self.imdb)
             self.assertEqual(len(nugget.tokenmap), 1)
             for token, variable in nugget.tokenmap.items():
                 self.assertEqual(token.text, t)
             self.assertRegex(nugget.template, r'{{ df\["name"\].iloc\[-*\d+\] }}')

+    def test_search_short_strings(self):
+        # Check strings that are shorter than the max length of the df,
+        # but still not a literal match
+        nugget = search.templatize(nlp('Dexter is a good show'), {}, self.imdb)
+        self.assertEqual(len(nugget.tokenmap), 1)
+        token, variable = nugget.tokenmap.popitem()
+        self.assertRegex(variable.enabled_source['tmpl'], r'df\["name"\].iloc\[-*\d+\]')


 if __name__ == "__main__":
     unittest.main()
11 changes: 9 additions & 2 deletions nlg/utils.py
@@ -8,7 +8,7 @@
 import re

 import pandas as pd
-from spacy.tokens import Token
+from spacy.tokens import Token, Doc, Span
 from tornado.template import Template

 from gramex.data import filter as gfilter  # NOQA: F401
@@ -111,12 +111,19 @@ def __call__(self, func):

 def is_overlap(x, y):
     """Whether the token x is contained within any span in the sequence y."""
+    if len(y) == 0:
+        return False
     if isinstance(x, Token):
         if x.pos_ == "NUM":
             return False
     elif 'NUM' in [c.pos_ for c in x]:
         return False
-    return any([x.text in yy.text for yy in y])
+    if len(y) > 1:
+        return any([x.text in yy.text for yy in y])
+    y = y.pop()
+    if isinstance(x, (Token, Span)) and isinstance(y, Doc):
+        return x.doc == y
+    return False


 def unoverlap(tokens):
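
A usage sketch of the updated is_overlap — the import path and model name are assumptions:

```python
# Sketch only: assumes the package imports as nlg and that an English
# spaCy model is installed.
import spacy
from nlg.utils import is_overlap

nlp = spacy.load('en_core_web_sm')
doc = nlp('Dexter is a good show')
token = doc[0]                                  # "Dexter"

print(is_overlap(token, []))                    # False: the new empty-sequence guard
print(is_overlap(token, [doc[0:1], doc[1:3]]))  # True: "Dexter" occurs in the first span
print(is_overlap(token, [doc]))                 # True: the single element is the token's own Doc
```

Note the new single-element branch: when the sequence holds exactly one item and it is a whole Doc, the check becomes identity against the token's own document rather than a substring test.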