Skip to content

Commit

Permalink
Add command line arguments for testing queries (#1161)
Browse files Browse the repository at this point in the history
* add filepath arg

* data_dir and queries_dir

---------

Co-authored-by: Ayush Dattagupta <[email protected]>
  • Loading branch information
sarahyurick and ayushdg authored Jun 20, 2023
1 parent 219f015 commit 8991706
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
12 changes: 12 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
def pytest_addoption(parser):
parser.addoption("--rungpu", action="store_true", help="run tests meant for GPU")
parser.addoption("--runqueries", action="store_true", help="run test queries")
parser.addoption("--data_dir", help="specify file path to the data")
parser.addoption("--queries_dir", help="specify file path to the queries")


def pytest_runtest_setup(item):
Expand All @@ -21,3 +23,13 @@ def pytest_runtest_setup(item):
dask.config.set({"dataframe.shuffle.algorithm": None})
if "queries" in item.keywords and not item.config.getoption("--runqueries"):
pytest.skip("need --runqueries option to run")


@pytest.fixture(scope="session")
def data_dir(request):
return request.config.getoption("--data_dir")


@pytest.fixture(scope="session")
def queries_dir(request):
return request.config.getoption("--queries_dir")
28 changes: 18 additions & 10 deletions tests/unit/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,17 @@


@pytest.fixture(scope="module")
def c():
def c(data_dir):
# Lazy import, otherwise the pytest framework has problems
from dask_sql.context import Context

c = Context()
for table_name in os.listdir(f"{os.path.dirname(__file__)}/data/"):
if not data_dir:
data_dir = f"{os.path.dirname(__file__)}/data/"
for table_name in os.listdir(data_dir):
c.create_table(
table_name,
f"{os.path.dirname(__file__)}/data/{table_name}",
data_dir + "/" + table_name,
format="parquet",
gpu=False,
)
Expand All @@ -68,17 +70,19 @@ def c():


@pytest.fixture(scope="module")
def gpu_c():
def gpu_c(data_dir):
pytest.importorskip("dask_cudf")

# Lazy import, otherwise the pytest framework has problems
from dask_sql.context import Context

c = Context()
for table_name in os.listdir(f"{os.path.dirname(__file__)}/data/"):
if not data_dir:
data_dir = f"{os.path.dirname(__file__)}/data/"
for table_name in os.listdir(data_dir):
c.create_table(
table_name,
f"{os.path.dirname(__file__)}/data/{table_name}",
data_dir + "/" + table_name,
format="parquet",
gpu=True,
)
Expand All @@ -88,8 +92,10 @@ def gpu_c():

@pytest.mark.queries
@pytest.mark.parametrize("query", QUERIES)
def test_query(c, client, query):
with open(f"{os.path.dirname(__file__)}/queries/{query}") as f:
def test_query(c, client, query, queries_dir):
if not queries_dir:
queries_dir = f"{os.path.dirname(__file__)}/queries/"
with open(queries_dir + "/" + query) as f:
sql = f.read()

res = c.sql(sql)
Expand All @@ -99,8 +105,10 @@ def test_query(c, client, query):
@pytest.mark.gpu
@pytest.mark.queries
@pytest.mark.parametrize("query", QUERIES)
def test_gpu_query(gpu_c, gpu_client, query):
with open(f"{os.path.dirname(__file__)}/queries/{query}") as f:
def test_gpu_query(gpu_c, gpu_client, query, queries_dir):
if not queries_dir:
queries_dir = f"{os.path.dirname(__file__)}/queries/"
with open(queries_dir + "/" + query) as f:
sql = f.read()

res = gpu_c.sql(sql)
Expand Down

0 comments on commit 8991706

Please sign in to comment.