Skip to content

Commit

Permalink
TST: Replace yield-based tests in test_query_eval
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAugspurger authored and gfyoung committed May 4, 2017
1 parent 0fb4854 commit 7bed86a
Showing 1 changed file with 28 additions and 78 deletions.
106 changes: 28 additions & 78 deletions pandas/tests/frame/test_query_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import operator
import pytest
from itertools import product

from pandas.compat import (zip, range, lrange, StringIO)
from pandas import DataFrame, Series, Index, MultiIndex, date_range
Expand All @@ -27,6 +26,16 @@
ENGINES = 'python', 'numexpr'


@pytest.fixture(params=PARSERS, ids=lambda x: x)
def parser(request):
return request.param


@pytest.fixture(params=ENGINES, ids=lambda x: x)
def engine(request):
return request.param


def skip_if_no_pandas_parser(parser):
if parser != 'pandas':
pytest.skip("cannot evaluate with parser {0!r}".format(parser))
Expand Down Expand Up @@ -165,8 +174,9 @@ def test_eval_resolvers_as_list(self):

class TestDataFrameQueryWithMultiIndex(tm.TestCase):

def check_query_with_named_multiindex(self, parser, engine):
def test_query_with_named_multiindex(self, parser, engine):
tm.skip_if_no_ne(engine)
skip_if_no_pandas_parser(parser)
a = np.random.choice(['red', 'green'], size=10)
b = np.random.choice(['eggs', 'ham'], size=10)
index = MultiIndex.from_arrays([a, b], names=['color', 'food'])
Expand Down Expand Up @@ -214,12 +224,9 @@ def check_query_with_named_multiindex(self, parser, engine):
assert_frame_equal(res1, exp)
assert_frame_equal(res2, exp)

def test_query_with_named_multiindex(self):
for parser, engine in product(['pandas'], ENGINES):
yield self.check_query_with_named_multiindex, parser, engine

def check_query_with_unnamed_multiindex(self, parser, engine):
def test_query_with_unnamed_multiindex(self, parser, engine):
tm.skip_if_no_ne(engine)
skip_if_no_pandas_parser(parser)
a = np.random.choice(['red', 'green'], size=10)
b = np.random.choice(['eggs', 'ham'], size=10)
index = MultiIndex.from_arrays([a, b])
Expand Down Expand Up @@ -308,12 +315,9 @@ def check_query_with_unnamed_multiindex(self, parser, engine):
assert_frame_equal(res1, exp)
assert_frame_equal(res2, exp)

def test_query_with_unnamed_multiindex(self):
for parser, engine in product(['pandas'], ENGINES):
yield self.check_query_with_unnamed_multiindex, parser, engine

def check_query_with_partially_named_multiindex(self, parser, engine):
def test_query_with_partially_named_multiindex(self, parser, engine):
tm.skip_if_no_ne(engine)
skip_if_no_pandas_parser(parser)
a = np.random.choice(['red', 'green'], size=10)
b = np.arange(10)
index = MultiIndex.from_arrays([a, b])
Expand Down Expand Up @@ -341,17 +345,7 @@ def check_query_with_partially_named_multiindex(self, parser, engine):
exp = df[ind != "red"]
assert_frame_equal(res, exp)

def test_query_with_partially_named_multiindex(self):
for parser, engine in product(['pandas'], ENGINES):
yield (self.check_query_with_partially_named_multiindex,
parser, engine)

def test_query_multiindex_get_index_resolvers(self):
for parser, engine in product(['pandas'], ENGINES):
yield (self.check_query_multiindex_get_index_resolvers, parser,
engine)

def check_query_multiindex_get_index_resolvers(self, parser, engine):
df = mkdf(10, 3, r_idx_nlevels=2, r_idx_names=['spam', 'eggs'])
resolvers = df._get_index_resolvers()

Expand All @@ -375,22 +369,14 @@ def to_series(mi, level):
else:
raise AssertionError("object must be a Series or Index")

def test_raise_on_panel_with_multiindex(self):
for parser, engine in product(PARSERS, ENGINES):
yield self.check_raise_on_panel_with_multiindex, parser, engine

def check_raise_on_panel_with_multiindex(self, parser, engine):
def test_raise_on_panel_with_multiindex(self, parser, engine):
tm.skip_if_no_ne()
p = tm.makePanel(7)
p.items = tm.makeCustomIndex(len(p.items), nlevels=2)
with pytest.raises(NotImplementedError):
pd.eval('p + 1', parser=parser, engine=engine)

def test_raise_on_panel4d_with_multiindex(self):
for parser, engine in product(PARSERS, ENGINES):
yield self.check_raise_on_panel4d_with_multiindex, parser, engine

def check_raise_on_panel4d_with_multiindex(self, parser, engine):
def test_raise_on_panel4d_with_multiindex(self, parser, engine):
tm.skip_if_no_ne()
p4d = tm.makePanel4D(7)
p4d.items = tm.makeCustomIndex(len(p4d.items), nlevels=2)
Expand Down Expand Up @@ -874,7 +860,7 @@ def test_query_builtin(self):

class TestDataFrameQueryStrings(tm.TestCase):

def check_str_query_method(self, parser, engine):
def test_str_query_method(self, parser, engine):
tm.skip_if_no_ne(engine)
df = DataFrame(randn(10, 1), columns=['b'])
df['strings'] = Series(list('aabbccddee'))
Expand Down Expand Up @@ -911,15 +897,7 @@ def check_str_query_method(self, parser, engine):
assert_frame_equal(res, expect)
assert_frame_equal(res, df[~df.strings.isin(['a'])])

def test_str_query_method(self):
for parser, engine in product(PARSERS, ENGINES):
yield self.check_str_query_method, parser, engine

def test_str_list_query_method(self):
for parser, engine in product(PARSERS, ENGINES):
yield self.check_str_list_query_method, parser, engine

def check_str_list_query_method(self, parser, engine):
def test_str_list_query_method(self, parser, engine):
tm.skip_if_no_ne(engine)
df = DataFrame(randn(10, 1), columns=['b'])
df['strings'] = Series(list('aabbccddee'))
Expand Down Expand Up @@ -958,7 +936,7 @@ def check_str_list_query_method(self, parser, engine):
parser=parser)
assert_frame_equal(res, expect)

def check_query_with_string_columns(self, parser, engine):
def test_query_with_string_columns(self, parser, engine):
tm.skip_if_no_ne(engine)
df = DataFrame({'a': list('aaaabbbbcccc'),
'b': list('aabbccddeeff'),
Expand All @@ -979,11 +957,7 @@ def check_query_with_string_columns(self, parser, engine):
with pytest.raises(NotImplementedError):
df.query('a in b and c < d', parser=parser, engine=engine)

def test_query_with_string_columns(self):
for parser, engine in product(PARSERS, ENGINES):
yield self.check_query_with_string_columns, parser, engine

def check_object_array_eq_ne(self, parser, engine):
def test_object_array_eq_ne(self, parser, engine):
tm.skip_if_no_ne(engine)
df = DataFrame({'a': list('aaaabbbbcccc'),
'b': list('aabbccddeeff'),
Expand All @@ -997,11 +971,7 @@ def check_object_array_eq_ne(self, parser, engine):
exp = df[df.a != df.b]
assert_frame_equal(res, exp)

def test_object_array_eq_ne(self):
for parser, engine in product(PARSERS, ENGINES):
yield self.check_object_array_eq_ne, parser, engine

def check_query_with_nested_strings(self, parser, engine):
def test_query_with_nested_strings(self, parser, engine):
tm.skip_if_no_ne(engine)
skip_if_no_pandas_parser(parser)
raw = """id event timestamp
Expand All @@ -1025,11 +995,7 @@ def check_query_with_nested_strings(self, parser, engine):
engine=engine)
assert_frame_equal(expected, res)

def test_query_with_nested_string(self):
for parser, engine in product(PARSERS, ENGINES):
yield self.check_query_with_nested_strings, parser, engine

def check_query_with_nested_special_character(self, parser, engine):
def test_query_with_nested_special_character(self, parser, engine):
skip_if_no_pandas_parser(parser)
tm.skip_if_no_ne(engine)
df = DataFrame({'a': ['a', 'b', 'test & test'],
Expand All @@ -1038,12 +1004,7 @@ def check_query_with_nested_special_character(self, parser, engine):
expec = df[df.a == 'test & test']
assert_frame_equal(res, expec)

def test_query_with_nested_special_character(self):
for parser, engine in product(PARSERS, ENGINES):
yield (self.check_query_with_nested_special_character,
parser, engine)

def check_query_lex_compare_strings(self, parser, engine):
def test_query_lex_compare_strings(self, parser, engine):
tm.skip_if_no_ne(engine=engine)
import operator as opr

Expand All @@ -1058,11 +1019,7 @@ def check_query_lex_compare_strings(self, parser, engine):
expected = df[func(df.X, 'd')]
assert_frame_equal(res, expected)

def test_query_lex_compare_strings(self):
for parser, engine in product(PARSERS, ENGINES):
yield self.check_query_lex_compare_strings, parser, engine

def check_query_single_element_booleans(self, parser, engine):
def test_query_single_element_booleans(self, parser, engine):
tm.skip_if_no_ne(engine)
columns = 'bid', 'bidsize', 'ask', 'asksize'
data = np.random.randint(2, size=(1, len(columns))).astype(bool)
Expand All @@ -1071,23 +1028,16 @@ def check_query_single_element_booleans(self, parser, engine):
expected = df[df.bid & df.ask]
assert_frame_equal(res, expected)

def test_query_single_element_booleans(self):
for parser, engine in product(PARSERS, ENGINES):
yield self.check_query_single_element_booleans, parser, engine

def check_query_string_scalar_variable(self, parser, engine):
def test_query_string_scalar_variable(self, parser, engine):
tm.skip_if_no_ne(engine)
skip_if_no_pandas_parser(parser)
df = pd.DataFrame({'Symbol': ['BUD US', 'BUD US', 'IBM US', 'IBM US'],
'Price': [109.70, 109.72, 183.30, 183.35]})
e = df[df.Symbol == 'BUD US']
symb = 'BUD US' # noqa
r = df.query('Symbol == @symb', parser=parser, engine=engine)
assert_frame_equal(e, r)

def test_query_string_scalar_variable(self):
for parser, engine in product(['pandas'], ENGINES):
yield self.check_query_string_scalar_variable, parser, engine


class TestDataFrameEvalNumExprPandas(tm.TestCase):

Expand Down

0 comments on commit 7bed86a

Please sign in to comment.