Skip to content

Commit

Permalink
Support stdin with asyncio (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
aklajnert authored Jan 23, 2022
1 parent dc3a0e2 commit 9e28d87
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 3 deletions.
5 changes: 5 additions & 0 deletions changelog.d/feature.0f8582b8.entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
message: Add support for stdin with asyncio.
pr_ids:
- '71'
timestamp: 1642946529
type: feature
31 changes: 28 additions & 3 deletions pytest_subprocess/fake_popen.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,9 @@ def run_thread(self) -> None:
def _finish_process(self) -> None:
self.returncode = self.__returncode

self._finalize_streams()

def _finalize_streams(self) -> None:
self._finalize_stream(self.stdout)
self._finalize_stream(self.stderr)

Expand All @@ -301,13 +304,21 @@ def received_signals(self) -> Tuple[int, ...]:
class AsyncFakePopen(FakePopen):
"""Class to handle async processes"""

stdout: asyncio.StreamReader
stderr: asyncio.StreamReader
stdout: Optional[asyncio.StreamReader]
stderr: Optional[asyncio.StreamReader]

async def communicate( # type: ignore
self, input: OPTIONAL_TEXT = None, timeout: Optional[float] = None
) -> Tuple[AnyType, AnyType]:
self._handle_stdin(input)
if input:
# streams were fed with eof, need to be reopened
await self._reopen_streams()

self._handle_stdin(input)

# feed eof one more time as streams were opened
self._finalize_streams()

self._finalize_thread(timeout)

return (
Expand All @@ -320,3 +331,17 @@ async def wait(self, timeout: Optional[float] = None) -> int: # type: ignore

def _get_empty_buffer(self, _: bool) -> asyncio.StreamReader:
return asyncio.StreamReader()

async def _reopen_streams(self) -> None:
self.stdout = await self._reopen_stream(self.stdout)
self.stderr = await self._reopen_stream(self.stderr)

async def _reopen_stream(
self, stream: Optional[asyncio.StreamReader]
) -> Optional[asyncio.StreamReader]:
if stream:
data = await stream.read()
fresh_stream = self._get_empty_buffer(False)
fresh_stream.feed_data(data)
return fresh_stream
return None
36 changes: 36 additions & 0 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,42 @@ async def _read_stream(stream: asyncio.StreamReader, output_list):
output_list.append(line.decode())


@pytest.mark.asyncio
@pytest.mark.parametrize("fake", [False, True])
async def test_input(fake_process, fake):
fake_process.allow_unregistered(not fake)
if fake:

def stdin_callable(input):
return {
"stdout": "Provide an input: Provided: {data}".format(
data=input.decode()
)
}

fake_process.register_subprocess(
["python", "example_script.py", "input"],
stdout=[b"Stdout line 1", b"Stdout line 2"],
stdin_callable=stdin_callable,
)

process = await asyncio.create_subprocess_exec(
"python",
"example_script.py",
"input",
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
)
out, err = await process.communicate(input=b"test")

assert out.splitlines() == [
b"Stdout line 1",
b"Stdout line 2",
b"Provide an input: Provided: test",
]
assert err is None


@pytest.fixture(autouse=True)
def skip_on_pypy():
"""Async test for some reason crash on pypy 3.6 on Windows"""
Expand Down

0 comments on commit 9e28d87

Please sign in to comment.