Skip to content

Commit

Permalink
TST: Use more pytest fixtures (pandas-dev#53679)
Browse files Browse the repository at this point in the history
* TST: Use more pytest fixtures

* Fix sql fixture

* More test fixtures
  • Loading branch information
mroeschke authored and root committed Jun 23, 2023
1 parent fac98dd commit 525ece3
Show file tree
Hide file tree
Showing 12 changed files with 100 additions and 125 deletions.
12 changes: 0 additions & 12 deletions pandas/tests/interchange/conftest.py

This file was deleted.

70 changes: 35 additions & 35 deletions pandas/tests/interchange/test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,47 +16,31 @@
)
from pandas.core.interchange.from_dataframe import from_dataframe

test_data_categorical = {
"ordered": pd.Categorical(list("testdata") * 30, ordered=True),
"unordered": pd.Categorical(list("testdata") * 30, ordered=False),
}

NCOLS, NROWS = 100, 200


def _make_data(make_one):
@pytest.fixture
def data_categorical():
return {
f"col{int((i - NCOLS / 2) % NCOLS + 1)}": [make_one() for _ in range(NROWS)]
for i in range(NCOLS)
"ordered": pd.Categorical(list("testdata") * 30, ordered=True),
"unordered": pd.Categorical(list("testdata") * 30, ordered=False),
}


int_data = _make_data(lambda: random.randint(-100, 100))
uint_data = _make_data(lambda: random.randint(1, 100))
bool_data = _make_data(lambda: random.choice([True, False]))
float_data = _make_data(lambda: random.random())
datetime_data = _make_data(
lambda: datetime(
year=random.randint(1900, 2100),
month=random.randint(1, 12),
day=random.randint(1, 20),
)
)

string_data = {
"separator data": [
"abC|DeF,Hik",
"234,3245.67",
"gSaf,qWer|Gre",
"asd3,4sad|",
np.NaN,
]
}
@pytest.fixture
def string_data():
return {
"separator data": [
"abC|DeF,Hik",
"234,3245.67",
"gSaf,qWer|Gre",
"asd3,4sad|",
np.NaN,
]
}


@pytest.mark.parametrize("data", [("ordered", True), ("unordered", False)])
def test_categorical_dtype(data):
df = pd.DataFrame({"A": (test_data_categorical[data[0]])})
def test_categorical_dtype(data, data_categorical):
df = pd.DataFrame({"A": (data_categorical[data[0]])})

col = df.__dataframe__().get_column_by_name("A")
assert col.dtype[0] == DtypeKind.CATEGORICAL
Expand Down Expand Up @@ -143,9 +127,25 @@ def test_bitmasks_pyarrow(offset, length, expected_values):


@pytest.mark.parametrize(
"data", [int_data, uint_data, float_data, bool_data, datetime_data]
"data",
[
lambda: random.randint(-100, 100),
lambda: random.randint(1, 100),
lambda: random.random(),
lambda: random.choice([True, False]),
lambda: datetime(
year=random.randint(1900, 2100),
month=random.randint(1, 12),
day=random.randint(1, 20),
),
],
)
def test_dataframe(data):
NCOLS, NROWS = 10, 20
data = {
f"col{int((i - NCOLS / 2) % NCOLS + 1)}": [data() for _ in range(NROWS)]
for i in range(NCOLS)
}
df = pd.DataFrame(data)

df2 = df.__dataframe__()
Expand Down Expand Up @@ -227,7 +227,7 @@ def test_mixed_missing():
assert df2.get_column_by_name(col_name).null_count == 2


def test_string():
def test_string(string_data):
test_str_data = string_data["separator data"] + [""]
df = pd.DataFrame({"A": test_str_data})
col = df.__dataframe__().get_column_by_name("A")
Expand Down
11 changes: 11 additions & 0 deletions pandas/tests/interchange/test_spec_conformance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@

import pytest

import pandas as pd


@pytest.fixture
def df_from_dict():
def maker(dct, is_categorical=False):
df = pd.DataFrame(dct)
return df.astype("category") if is_categorical else df

return maker


@pytest.mark.parametrize(
"test_data",
Expand Down
4 changes: 1 addition & 3 deletions pandas/tests/io/excel/test_xlrd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@

xlrd = pytest.importorskip("xlrd")

exts = [".xls"]


@pytest.fixture(params=exts)
@pytest.fixture(params=[".xls"])
def read_ext_xlrd(request):
"""
Valid extensions for reading Excel files with xlrd.
Expand Down
12 changes: 8 additions & 4 deletions pandas/tests/io/formats/style/test_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
jinja2 = pytest.importorskip("jinja2")
from pandas.io.formats.style import Styler

loader = jinja2.PackageLoader("pandas", "io/formats/templates")
env = jinja2.Environment(loader=loader, trim_blocks=True)

@pytest.fixture
def env():
loader = jinja2.PackageLoader("pandas", "io/formats/templates")
env = jinja2.Environment(loader=loader, trim_blocks=True)
return env


@pytest.fixture
Expand All @@ -31,12 +35,12 @@ def styler_mi():


@pytest.fixture
def tpl_style():
def tpl_style(env):
return env.get_template("html_style.tpl")


@pytest.fixture
def tpl_table():
def tpl_table(env):
return env.get_template("html_table.tpl")


Expand Down
7 changes: 0 additions & 7 deletions pandas/tests/io/formats/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,6 @@

from pandas._config import config

from pandas.compat import (
IS64,
is_platform_windows,
)

import pandas as pd
from pandas import (
DataFrame,
Expand All @@ -48,8 +43,6 @@
from pandas.io.formats import printing
import pandas.io.formats.format as fmt

use_32bit_repr = is_platform_windows() or not IS64


def get_local_am_pm():
"""Return the AM and PM strings returned by strftime in current locale."""
Expand Down
9 changes: 0 additions & 9 deletions pandas/tests/io/parser/usecols/test_parse_dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,6 @@
)
import pandas._testing as tm

_msg_validate_usecols_arg = (
"'usecols' must either be list-like "
"of all strings, all unicode, all "
"integers or a callable."
)
_msg_validate_usecols_names = (
"Usecols do not match columns, columns expected but not found: {0}"
)

# TODO(1.4): Change these to xfails whenever parse_dates support(which was
# intentionally disable to keep small PR sizes) is added back
pytestmark = pytest.mark.usefixtures("pyarrow_skip")
Expand Down
15 changes: 5 additions & 10 deletions pandas/tests/io/parser/usecols/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,6 @@
from pandas import DataFrame
import pandas._testing as tm

_msg_validate_usecols_arg = (
"'usecols' must either be list-like "
"of all strings, all unicode, all "
"integers or a callable."
)
_msg_validate_usecols_names = (
"Usecols do not match columns, columns expected but not found: {0}"
)


def test_usecols_with_unicode_strings(all_parsers):
# see gh-13219
Expand Down Expand Up @@ -70,7 +61,11 @@ def test_usecols_with_mixed_encoding_strings(all_parsers, usecols):
2.613230982,2,False,b
3.568935038,7,False,a"""
parser = all_parsers

_msg_validate_usecols_arg = (
"'usecols' must either be list-like "
"of all strings, all unicode, all "
"integers or a callable."
)
with pytest.raises(ValueError, match=_msg_validate_usecols_arg):
parser.read_csv(StringIO(data), usecols=usecols)

Expand Down
7 changes: 2 additions & 5 deletions pandas/tests/io/pytables/test_round_trip.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
)
from pandas.util import _test_decorators as td

_default_compressor = "blosc"


pytestmark = pytest.mark.single_cpu


Expand Down Expand Up @@ -479,7 +476,7 @@ def _make_one():
def _check_roundtrip(obj, comparator, path, compression=False, **kwargs):
options = {}
if compression:
options["complib"] = _default_compressor
options["complib"] = "blosc"

with ensure_clean_store(path, "w", **options) as store:
store["obj"] = obj
Expand All @@ -490,7 +487,7 @@ def _check_roundtrip(obj, comparator, path, compression=False, **kwargs):
def _check_roundtrip_table(obj, comparator, path, compression=False):
options = {}
if compression:
options["complib"] = _default_compressor
options["complib"] = "blosc"

with ensure_clean_store(path, "w", **options) as store:
store.put("obj", obj, format="table")
Expand Down
2 changes: 0 additions & 2 deletions pandas/tests/io/pytables/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@
safe_close,
)

_default_compressor = "blosc"

from pandas.io.pytables import (
HDFStore,
read_hdf,
Expand Down
Loading

0 comments on commit 525ece3

Please sign in to comment.