Skip to content

Commit

Permalink
REFACTOR-#2739: io tests refactoring (#2740)
Browse files Browse the repository at this point in the history
Signed-off-by: Alexander Myskov <[email protected]>
  • Loading branch information
amyskov authored Feb 16, 2021
1 parent aa818f5 commit 8ebbad9
Show file tree
Hide file tree
Showing 4 changed files with 338 additions and 330 deletions.
181 changes: 181 additions & 0 deletions modin/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
import os
import sys
import pytest
import pandas
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import shutil

import modin
import modin.config
Expand All @@ -23,6 +28,17 @@
from modin.engines.python.pandas_on_python.io import PandasOnPythonIO
from modin.data_management.factories import factories
from modin.utils import get_current_backend
from modin.pandas.test.utils import (
_make_csv_file,
get_unique_filename,
teardown_test_files,
NROWS,
IO_OPS_DATA_DIR,
)

# create the test data dir if it does not exist yet
if not os.path.exists(IO_OPS_DATA_DIR):
os.mkdir(IO_OPS_DATA_DIR)


def pytest_addoption(parser):
Expand Down Expand Up @@ -232,3 +248,168 @@ def pytest_runtest_call(item):
**marker.kwargs,
)
)


@pytest.fixture(scope="class")
def TestReadCSVFixture():
    """Class-scoped fixture creating the shared csv files for read_csv tests.

    Publishes the generated unique filenames on the ``pytest`` module as
    ``pytest.csvs_names`` and deletes every created file on teardown.
    """
    created_files = []
    file_ids = [
        "test_read_csv_regular",
        "test_read_csv_blank_lines",
        "test_read_csv_yes_no",
        "test_read_csv_nans",
        "test_read_csv_bad_lines",
    ]
    # each xdist worker is spawned in a separate process with its own
    # namespace and dataset, so unique names avoid cross-worker collisions
    pytest.csvs_names = {file_id: get_unique_filename() for file_id in file_ids}
    # (file id, extra csv-generator kwargs); the comments name the test
    # groups that consume each file
    csv_specs = [
        # test_read_csv_col_handling, test_read_csv_parsing
        ("test_read_csv_regular", {}),
        # test_read_csv_parsing
        (
            "test_read_csv_yes_no",
            {"additional_col_values": ["Yes", "true", "No", "false"]},
        ),
        # test_read_csv_col_handling
        ("test_read_csv_blank_lines", {"add_blank_lines": True}),
        # test_read_csv_nans_handling
        (
            "test_read_csv_nans",
            {
                "add_blank_lines": True,
                "additional_col_values": [
                    "<NA>",
                    "N/A",
                    "NA",
                    "NULL",
                    "custom_nan",
                    "73",
                ],
            },
        ),
        # test_read_csv_error_handling
        ("test_read_csv_bad_lines", {"add_bad_lines": True}),
    ]
    for file_id, extra_kwargs in csv_specs:
        _make_csv_file(created_files)(
            filename=pytest.csvs_names[file_id], **extra_kwargs
        )

    yield
    # Delete csv files that were created
    teardown_test_files(created_files)


@pytest.fixture
def make_csv_file():
    """Yield a factory that writes temporary csv files for a test.

    Every file produced through the factory is recorded and removed
    once the test using the fixture finishes.
    """
    created = []

    yield _make_csv_file(created)

    # remove everything the factory produced during the test
    teardown_test_files(created)


@pytest.fixture
def make_parquet_file():
    """Pytest fixture factory that makes a parquet file/dir for testing.

    Yields:
        Function that generates a parquet file/dir
    """
    filenames = []

    def _make_parquet_file(
        filename,
        row_size=NROWS,
        force=True,
        directory=False,
        partitioned_columns=None,
    ):
        """Helper function to generate parquet files/directories.

        Args:
            filename: The name of test file, that should be created.
            row_size: Number of rows for the dataframe.
            force: Create a new file/directory even if one already exists.
            directory: Create a partitioned directory using pyarrow.
            partitioned_columns: Create a partitioned directory using pandas.
                Will be ignored if directory=True.
        """
        # avoid the shared-mutable-default pitfall: fresh list per call
        if partitioned_columns is None:
            partitioned_columns = []
        df = pandas.DataFrame(
            {"col1": np.arange(row_size), "col2": np.arange(row_size)}
        )
        if os.path.exists(filename) and not force:
            # file already present and caller did not ask to overwrite
            pass
        elif directory:
            # build a pyarrow dataset rooted at `filename`; drop any stale
            # directory first (write_to_dataset creates the root if needed)
            if os.path.exists(filename):
                shutil.rmtree(filename)
            else:
                os.mkdir(filename)
            table = pa.Table.from_pandas(df)
            pq.write_to_dataset(table, root_path=filename)
        elif len(partitioned_columns) > 0:
            # pandas-partitioned parquet directory
            df.to_parquet(filename, partition_cols=partitioned_columns)
        else:
            df.to_parquet(filename)

        filenames.append(filename)

    # Return function that generates parquet files/directories
    yield _make_parquet_file

    # Delete parquet files/directories that were created
    for path in filenames:
        if os.path.exists(path):
            if os.path.isdir(path):
                shutil.rmtree(path)
            else:
                os.remove(path)


@pytest.fixture
def make_sql_connection():
"""Sets up sql connections and takes them down after the caller is done.
Yields:
Factory that generates sql connection objects
"""
filenames = []

def _sql_connection(filename, table=""):
# Remove file if exists
if os.path.exists(filename):
os.remove(filename)
filenames.append(filename)
# Create connection and, if needed, table
conn = "sqlite:///{}".format(filename)
if table:
df = pandas.DataFrame(
{
"col1": [0, 1, 2, 3, 4, 5, 6],
"col2": [7, 8, 9, 10, 11, 12, 13],
"col3": [14, 15, 16, 17, 18, 19, 20],
"col4": [21, 22, 23, 24, 25, 26, 27],
"col5": [0, 0, 0, 0, 0, 0, 0],
}
)
df.to_sql(table, conn)
return conn

yield _sql_connection

# Teardown the fixture
teardown_test_files(filenames)


@pytest.fixture(scope="class")
def TestReadGlobCSVFixture():
    """Class-scoped fixture writing a family of csv files that match one
    glob pattern, published as ``pytest.glob_path`` / ``pytest.files``."""
    created = []

    stem = get_unique_filename(extension="")
    pytest.glob_path = "{}_*.csv".format(stem)
    pytest.files = ["{}_{}.csv".format(stem, idx) for idx in range(11)]
    for fname in pytest.files:
        # Glob does not guarantee ordering so we have to remove the
        # randomness in the generated csvs.
        _make_csv_file(created)(fname, row_size=11, remove_randomness=True)

    yield

    teardown_test_files(created)
25 changes: 1 addition & 24 deletions modin/experimental/pandas/test/test_io_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,7 @@
import pytest
import modin.experimental.pandas as pd
from modin.config import Engine
from modin.pandas.test.test_io import ( # noqa: F401
df_equals,
eval_io,
make_sql_connection,
_make_csv_file,
teardown_test_files,
)
from modin.pandas.test.utils import get_unique_filename
from modin.pandas.test.utils import df_equals


@pytest.mark.skipif(
Expand Down Expand Up @@ -69,22 +62,6 @@ def test_from_sql_defaults(make_sql_connection): # noqa: F811
df_equals(modin_df_from_table, pandas_df)


@pytest.fixture(scope="class")
def TestReadGlobCSVFixture():
    """Create csv files matching a single glob pattern for glob-path tests.

    Exposes the pattern as ``pytest.glob_path`` and the individual file
    names as ``pytest.files``; all created files are removed on teardown.
    """
    filenames = []

    base_name = get_unique_filename(extension="")
    pytest.glob_path = "{}_*.csv".format(base_name)
    pytest.files = ["{}_{}.csv".format(base_name, i) for i in range(11)]
    for fname in pytest.files:
        # Glob does not guarantee ordering so we have to remove the randomness in the generated csvs.
        _make_csv_file(filenames)(fname, row_size=11, remove_randomness=True)

    yield

    teardown_test_files(filenames)


@pytest.mark.usefixtures("TestReadGlobCSVFixture")
@pytest.mark.skipif(
Engine.get() != "Ray", reason="Currently only support Ray engine for glob paths."
Expand Down
Loading

0 comments on commit 8ebbad9

Please sign in to comment.