diff --git a/tests/integration/test_cmd.py b/tests/integration/test_cmd.py index 145b4962f..0f90da41d 100644 --- a/tests/integration/test_cmd.py +++ b/tests/integration/test_cmd.py @@ -1,6 +1,8 @@ +import prompt_toolkit import pytest from dask import config as dask_config from mock import MagicMock, patch +from packaging.version import parse as parseVersion from prompt_toolkit.application import create_app_session from prompt_toolkit.input import create_pipe_input from prompt_toolkit.output import DummyOutput @@ -8,15 +10,23 @@ from dask_sql.cmd import _meta_commands +_prompt_toolkit_version = parseVersion(prompt_toolkit.__version__) + @pytest.fixture(autouse=True, scope="function") def mock_prompt_input(): - pipe_input = create_pipe_input() - try: - with create_app_session(input=pipe_input, output=DummyOutput()): - yield pipe_input - finally: - pipe_input.close() + # TODO: remove if prompt-toolkit min version gets bumped + if _prompt_toolkit_version >= parseVersion("3.0.29"): + with create_pipe_input() as pipe_input: + with create_app_session(input=pipe_input, output=DummyOutput()): + yield pipe_input + else: + pipe_input = create_pipe_input() + try: + with create_app_session(input=pipe_input, output=DummyOutput()): + yield pipe_input + finally: + pipe_input.close() def _feed_cli_with_input(