Skip to content

Commit

Permalink
test: use fixtures (#207)
Browse the repository at this point in the history
  • Loading branch information
mbelak-dtml authored Mar 11, 2024
1 parent eb78c67 commit 9032158
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 173 deletions.
6 changes: 2 additions & 4 deletions tests/pyarrow_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import pytest

from edvart.data_types import PYARROW_PANDAS_BACKEND_AVAILABLE

# Parameter values for fixtures that build test DataFrames: a truthy value asks
# the fixture to convert the frame to the pyarrow dtype backend, a falsy value
# leaves the default pandas dtypes. The pyarrow variant is only included when
# the backend is reported as available.
pyarrow_params = [True, False] if PYARROW_PANDAS_BACKEND_AVAILABLE else [False]
32 changes: 16 additions & 16 deletions tests/test_bivariate_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
from edvart.report_sections.section_base import Verbosity

from .execution_utils import check_section_executes
from .pyarrow_utils import pyarrow_parameterize
from .pyarrow_utils import pyarrow_params


@pytest.fixture(params=pyarrow_params)
def test_df(request) -> pd.DataFrame:
    """Provide a small two-column DataFrame, once per entry in ``pyarrow_params``.

    Column ``A`` holds floats and column ``B`` holds strings. When the fixture
    parameter is truthy, the frame is converted to the pyarrow dtype backend.
    """
    frame = pd.DataFrame(
        columns=["A", "B"],
        data=[[1.1, "a"], [2.2, "b"], [3.3, "c"]],
    )
    if not request.param:
        return frame
    return frame.convert_dtypes(dtype_backend="pyarrow")
Expand Down Expand Up @@ -125,7 +126,7 @@ def test_section_adding():
), "Subsection should be ContingencyTable"


def test_code_export_verbosity_low():
def test_code_export_verbosity_low(test_df: pd.DataFrame):
bivariate_section = bivariate_analysis.BivariateAnalysis(verbosity=Verbosity.LOW)
# Export code
exported_cells = []
Expand All @@ -138,10 +139,10 @@ def test_code_export_verbosity_low():
assert len(exported_code) == 1
assert exported_code[0] == expected_code[0], "Exported code mismatch"

check_section_executes(bivariate_section, df=get_test_df())
check_section_executes(bivariate_section, df=test_df)


def test_code_export_verbosity_low_with_subsections():
def test_code_export_verbosity_low_with_subsections(test_df: pd.DataFrame):
bivariate_section = bivariate_analysis.BivariateAnalysis(
subsections=[
BivariateAnalysisSubsection.ContingencyTable,
Expand All @@ -164,7 +165,7 @@ def test_code_export_verbosity_low_with_subsections():
assert len(exported_code) == 1
assert exported_code[0] == expected_code[0], "Exported code mismatch"

check_section_executes(bivariate_section, df=get_test_df())
check_section_executes(bivariate_section, df=test_df)


def test_generated_code_verbosity_low_columns():
Expand Down Expand Up @@ -209,7 +210,7 @@ def test_generated_code_verbosity_low_columns():
check_section_executes(bivariate_section, df=test_df)


def test_generated_code_verbosity_medium():
def test_generated_code_verbosity_medium(test_df: pd.DataFrame):
bivariate_section = bivariate_analysis.BivariateAnalysis(
verbosity=Verbosity.MEDIUM,
subsections=[
Expand All @@ -233,7 +234,7 @@ def test_generated_code_verbosity_medium():
for expected_line, exported_line in zip(expected_code, exported_code):
assert expected_line == exported_line, "Exported code mismatch"

check_section_executes(bivariate_section, df=get_test_df())
check_section_executes(bivariate_section, df=test_df)


def test_generated_code_verbosity_medium_columns_x_y():
Expand Down Expand Up @@ -307,7 +308,7 @@ def test_generated_code_verbosity_medium_columns_pairs():
check_section_executes(bivariate_section, df=test_df)


def test_generated_code_verbosity_high():
def test_generated_code_verbosity_high(test_df: pd.DataFrame):
bivariate_section = bivariate_analysis.BivariateAnalysis(
verbosity=Verbosity.HIGH,
subsections=[
Expand Down Expand Up @@ -345,10 +346,10 @@ def test_generated_code_verbosity_high():
for expected_line, exported_line in zip(expected_code, exported_code):
assert expected_line == exported_line, "Exported code mismatch"

check_section_executes(bivariate_section, df=get_test_df())
check_section_executes(bivariate_section, df=test_df)


def test_verbosity_low_different_subsection_verbosities():
def test_verbosity_low_different_subsection_verbosities(test_df: pd.DataFrame):
bivariate_section = BivariateAnalysis(
verbosity=Verbosity.LOW,
subsections=[
Expand Down Expand Up @@ -377,7 +378,7 @@ def test_verbosity_low_different_subsection_verbosities():
for expected_line, exported_line in zip(expected_code, exported_code):
assert expected_line == exported_line, "Exported code mismatch"

check_section_executes(bivariate_section, df=get_test_df())
check_section_executes(bivariate_section, df=test_df)


def test_imports_verbosity_low():
Expand Down Expand Up @@ -449,10 +450,9 @@ def test_imports_verbosity_low_different_subsection_verbosities():
assert set(exported_imports) == set(expected_imports)


def test_show(test_df: pd.DataFrame):
    """Check that ``BivariateAnalysis.show`` runs on the fixture DataFrame.

    ``UserWarning`` is ignored and standard output is redirected to ``None``
    for the duration of the call; nothing is asserted about the output.
    """
    section = BivariateAnalysis()
    with warnings.catch_warnings(), redirect_stdout(None):
        warnings.simplefilter("ignore", UserWarning)
        section.show(test_df)
104 changes: 45 additions & 59 deletions tests/test_group_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,22 @@
from edvart.report_sections.section_base import Verbosity

from .execution_utils import check_section_executes
from .pyarrow_utils import pyarrow_parameterize
from .pyarrow_utils import pyarrow_params

# Workaround to prevent multiple browser tabs opening with figures
plotly.io.renderers.default = "json"


@pytest.fixture(params=pyarrow_params)
def test_df(request) -> pd.DataFrame:
    """Provide a 60-row DataFrame for group-analysis tests, once per ``pyarrow_params`` entry.

    Columns:
        A: ``"P"`` with probability 0.4, otherwise ``"N"``.
        B: ``1.5 * i`` for row number ``i``.
        C: ``"X"`` for even ``i``, ``"Y"`` for odd ``i``.

    When the fixture parameter is truthy, the frame is converted to the
    pyarrow dtype backend.
    """
    # Use a locally seeded generator instead of the global NumPy random state:
    # every test then sees the same data, so a failure can be reproduced.
    rng = np.random.default_rng(seed=0)
    rows = [
        ["P" if rng.uniform() < 0.4 else "N", 1.5 * i, "X" if i % 2 == 0 else "Y"]
        for i in range(60)
    ]
    frame = pd.DataFrame(data=rows, columns=["A", "B", "C"])
    if request.param:
        frame = frame.convert_dtypes(dtype_backend="pyarrow")
    return frame

Expand All @@ -53,51 +54,44 @@ def test_invalid_verbosities():
GroupAnalysis(groupby=[], verbosity=-1)


def test_groupby_nonexistent_col(test_df: pd.DataFrame):
    """Grouping by a column missing from the DataFrame must raise ``ValueError``.

    Checked for both ``show_group_analysis`` and ``group_missing_values``.
    """
    for analysis_func in (show_group_analysis, group_missing_values):
        with pytest.raises(ValueError):
            analysis_func(df=test_df, groupby=["non-existent"])


def test_static_methods(test_df: pd.DataFrame):
    """Smoke-test the group-analysis helper functions on the fixture DataFrame.

    Each helper is called with several ``groupby`` / column combinations while
    standard output is redirected to ``None``. Nothing is asserted about the
    results; the test passes when no call raises.
    """
    show_cases = [
        {"groupby": "C"},
        {"groupby": ["C"], "columns": ["A"]},
        {"groupby": ["C"], "columns": ["A", "B"]},
        {"groupby": "C", "columns": ["A", "B", "C"]},
        {"groupby": "C", "columns": ["C"]},
    ]
    # NOTE(review): the last two barplot cases and the middle two histogram
    # cases are identical; kept as-is -- confirm whether variants were intended.
    barplot_cases = [
        {"groupby": ["A"], "column": "B"},
        {"groupby": ["A"], "column": "A"},
        {"groupby": ["A", "C"], "column": "B"},
        {"groupby": ["A"], "column": "C"},
        {"groupby": ["A"], "column": "C"},
    ]
    missing_values_cases = [
        {"groupby": ["C"]},
        {"groupby": ["C"], "columns": ["A", "B"]},
        {"groupby": ["C"], "columns": ["A", "B", "C"]},
        {"groupby": ["C"], "columns": ["C"]},
    ]
    histogram_cases = [
        {"groupby": ["A"], "column": "B"},
        {"groupby": ["A", "C"], "column": "B"},
        {"groupby": ["A", "C"], "column": "B"},
        {"groupby": ["B"], "column": "B"},
    ]

    with redirect_stdout(None):
        for kwargs in show_cases:
            show_group_analysis(df=test_df, **kwargs)

        for kwargs in barplot_cases:
            group_barplot(test_df, **kwargs)

        for kwargs in missing_values_cases:
            group_missing_values(test_df, **kwargs)

        for kwargs in histogram_cases:
            overlaid_histograms(test_df, **kwargs)


def test_code_export_verbosity_low(test_df: pd.DataFrame):
group_section = GroupAnalysis(groupby="B", verbosity=Verbosity.LOW)

# Export code
exported_cells = []
group_section.add_cells(exported_cells, df=df)
group_section.add_cells(exported_cells, df=test_df)
# Remove markdown and other cells and get code strings
exported_code = [cell["source"] for cell in exported_cells if cell["cell_type"] == "code"]
# Define expected code
Expand All @@ -106,17 +100,15 @@ def test_code_export_verbosity_low(pyarrow_dtypes: bool):
assert len(exported_code) == 1
assert exported_code[0] == expected_code[0], "Exported code mismatch"

check_section_executes(group_section, df)
check_section_executes(group_section, test_df)


@pyarrow_parameterize
def test_code_export_verbosity_medium(pyarrow_dtypes: bool):
df = get_test_df(pyarrow_dtypes=pyarrow_dtypes)
def test_code_export_verbosity_medium(test_df: pd.DataFrame):
group_section = GroupAnalysis(groupby="A", verbosity=Verbosity.MEDIUM)

# Export code
exported_cells = []
group_section.add_cells(exported_cells, df=df)
group_section.add_cells(exported_cells, df=test_df)
# Remove markdown and other cells and get code strings
exported_code = [cell["source"] for cell in exported_cells if cell["cell_type"] == "code"]
# Define expected code
Expand All @@ -135,17 +127,15 @@ def test_code_export_verbosity_medium(pyarrow_dtypes: bool):
for expected_line, exported_line in zip(expected_code, exported_code):
assert expected_line == exported_line, "Exported code mismatch"

check_section_executes(group_section, df)
check_section_executes(group_section, test_df)


@pyarrow_parameterize
def test_code_export_verbosity_high(pyarrow_dtypes: bool):
df = get_test_df(pyarrow_dtypes=pyarrow_dtypes)
def test_code_export_verbosity_high(test_df: pd.DataFrame):
group_section = GroupAnalysis(groupby="A", verbosity=Verbosity.HIGH)

# Export code
exported_cells = []
group_section.add_cells(exported_cells, df=df)
group_section.add_cells(exported_cells, df=test_df)
# Remove markdown and other cells and get code strings
exported_code = [cell["source"] for cell in exported_cells if cell["cell_type"] == "code"]
# Define expected code
Expand Down Expand Up @@ -192,21 +182,19 @@ def test_code_export_verbosity_high(pyarrow_dtypes: bool):
for expected_line, exported_line in zip(expected_code, exported_code):
assert expected_line == exported_line, "Exported code mismatch"

check_section_executes(group_section, df)
check_section_executes(group_section, test_df)


def test_columns_parameter(test_df: pd.DataFrame):
    """Check how ``GroupAnalysis`` stores its ``groupby`` and ``columns`` arguments.

    A string ``groupby`` is expected to be stored as a one-element list, and
    ``columns`` is expected to stay ``None`` when omitted -- including after
    ``show`` and ``add_cells`` have been called with a DataFrame.
    """
    explicit_columns = GroupAnalysis(groupby="A", columns=["B"])
    assert explicit_columns.groupby == ["A"]
    assert explicit_columns.columns == ["B"]

    default_columns = GroupAnalysis(groupby="A")
    assert default_columns.groupby == ["A"]
    assert default_columns.columns is None

    # Running the section must not overwrite the configured attributes.
    default_columns.show(test_df)
    default_columns.add_cells([], df=test_df)
    assert default_columns.groupby == ["A"]
    assert default_columns.columns is None

Expand All @@ -217,11 +205,9 @@ def test_column_list_not_modified():
assert columns == ["C"], "Column list modified"


def test_show(test_df: pd.DataFrame):
    """Check that ``GroupAnalysis.show`` runs on the fixture DataFrame grouped by ``A``.

    ``UserWarning`` is ignored and standard output is redirected to ``None``
    for the duration of the call; nothing is asserted about the output.
    """
    section = GroupAnalysis(groupby="A")
    with warnings.catch_warnings(), redirect_stdout(None):
        warnings.simplefilter("ignore", UserWarning)
        section.show(test_df)
Loading

0 comments on commit 9032158

Please sign in to comment.