Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: Patch Serial.in_waiting too #24

Merged
merged 6 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ dev = [
"tox",
]

[tool.coverage.report]
show_missing = true

[tool.isort]
profile = "black"

Expand Down
64 changes: 54 additions & 10 deletions src/pytest_reserial/reserial.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Callable, Dict, Iterator, Literal, Tuple

import pytest
from serial import Serial # type: ignore[import-untyped]
from serial import PortNotOpenError, Serial # type: ignore[import-untyped]

TrafficLog = Dict[Literal["rx", "tx"], bytes]
PatchMethods = Tuple[
Expand All @@ -18,6 +18,7 @@
Callable[[Serial], None], # open
Callable[[Serial], None], # close
Callable[[Serial, bool], None], # _reconfigure_port
Callable[[Serial], int], # in_waiting
]


Expand Down Expand Up @@ -66,14 +67,20 @@ def reserial(
test_name = request.node.name
log = get_traffic_log(mode, log_path, test_name)

read_patch, write_patch, open_patch, close_patch, reconfigure_port_patch = (
get_patched_methods(mode, log)
)
(
read_patch,
write_patch,
open_patch,
close_patch,
reconfigure_port_patch,
in_waiting_patch,
) = get_patched_methods(mode, log)
monkeypatch.setattr(Serial, "read", read_patch)
monkeypatch.setattr(Serial, "write", write_patch)
monkeypatch.setattr(Serial, "open", open_patch)
monkeypatch.setattr(Serial, "close", close_patch)
monkeypatch.setattr(Serial, "_reconfigure_port", reconfigure_port_patch)
monkeypatch.setattr(Serial, "in_waiting", in_waiting_patch)

yield

Expand Down Expand Up @@ -136,7 +143,7 @@ def get_traffic_log(mode: Mode, log_path: Path, test_name: str) -> TrafficLog:


def get_patched_methods(mode: Mode, log: TrafficLog) -> PatchMethods:
"""Return patched read, write, open, and closed methods.
"""Return patched read, write, open, etc methods.

The methods should be monkeypatched over the corresponding `Serial` methods.

Expand All @@ -158,6 +165,10 @@ def get_patched_methods(mode: Mode, log: TrafficLog) -> PatchMethods:
Monkeypatch this over `Serial.open`.
close_patch: Callable[[Serial], None]
Monkeypatch this over `Serial.close`.
_reconfigure_port_patch: Callable[[Serial, bool], None]
Monkeypatch this over `Serial._reconfigure_port`.
in_waiting_patch: Callable[[Serial], int]
Monkeypatch this over `Serial.in_waiting`.
"""
if mode == Mode.REPLAY:
return get_replay_methods(log)
Expand All @@ -169,11 +180,12 @@ def get_patched_methods(mode: Mode, log: TrafficLog) -> PatchMethods:
Serial.open,
Serial.close,
Serial._reconfigure_port, # noqa: SLF001
Serial.in_waiting,
)


def get_replay_methods(log: TrafficLog) -> PatchMethods:
"""Return patched read, write, open, and close methods for replaying logged traffic.
"""Return patched read, write, open, etc methods for replaying logged traffic.

Parameters
----------
Expand All @@ -190,10 +202,15 @@ def get_replay_methods(log: TrafficLog) -> PatchMethods:
Sets `Serial.is_open` to `True`.
replay_close: Callable[[Serial], None]
Sets `Serial.is_open` to `False`.
replay_reconfigure_port: Callable[[Serial, bool], None]
No-op
record_in_waiting: Callable[[Serial], int]
Return the number of bytes of RX traffic left to replay.

"""

def replay_write(
self: Serial, # noqa: ARG001
self: Serial,
data: bytes,
) -> int:
"""Compare TX data to recording instead of writing to the bus.
Expand All @@ -206,6 +223,9 @@ def replay_write(
_pytest.outcomes.Failed
If written data does not match recorded data.
"""
if not self.is_open:
raise PortNotOpenError

if data == log["tx"][: len(data)]:
log["tx"] = log["tx"][len(data) :]
else:
Expand All @@ -218,19 +238,36 @@ def replay_write(
return len(data)

def replay_read(
self: Serial, # noqa: ARG001
self: Serial,
size: int = 1,
) -> bytes:
"""Replay RX data from recording instead of reading from the bus.

Monkeypatch this method over Serial.read to replay traffic. Parameters and
return values are identical to Serial.read.
"""
if not self.is_open:
raise PortNotOpenError

data = log["rx"][:size]
log["rx"] = log["rx"][size:]
return bytes(data)

return replay_read, replay_write, replay_open, replay_close, replay_reconfigure_port
@property # type: ignore[misc]
def replay_in_waiting(
self: Serial, # noqa:ARG001
) -> int:
"""Return the number of bytes in RX data left to replay."""
return len(log["rx"])

return (
replay_read,
replay_write,
replay_open,
replay_close,
replay_reconfigure_port,
replay_in_waiting,
)


# The open/close method patches don't need access to logs, so they can stay down here.
Expand Down Expand Up @@ -258,7 +295,7 @@ def replay_reconfigure_port(


def get_record_methods(log: TrafficLog) -> PatchMethods:
"""Return patched read, write, open, and close methods for recording traffic.
"""Return patched read, write, open, etc methods for recording traffic.

Parameters
----------
Expand All @@ -275,6 +312,12 @@ def get_record_methods(log: TrafficLog) -> PatchMethods:
Does not need to be patched when recording, so this is `Serial.open`.
record_close: Callable[[Serial], None]
Does not need to be patched when recording, so this is `Serial.close`.
record_reconfigure_port: Callable[[Serial, bool], None]
Does not need to be patched when recording,
so this is `Serial._reconfigure_port`.
record_in_waiting: Callable[[Serial], int]
Does not need to be patched when recording, so this is `Serial.in_waiting`.

"""
real_read = Serial.read
real_write = Serial.write
Expand Down Expand Up @@ -305,6 +348,7 @@ def record_read(self: Serial, size: int = 1) -> bytes:
Serial.open,
Serial.close,
Serial._reconfigure_port, # noqa: SLF001
Serial.in_waiting,
)


Expand Down
28 changes: 27 additions & 1 deletion tests/test_reserial.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,33 @@
def test_reserial(reserial):
s = serial.Serial(port="/dev/ttyUSB0")
s.write({TEST_TX!r})
assert s.in_waiting == {len(TEST_RX)}
assert s.read() == {TEST_RX!r}
def test_reserial2(reserial):
s = serial.Serial(port="/dev/ttyUSB0")
s.write({TEST_TX!r})
assert s.read() == {TEST_RX!r}
"""
TEST_FILE_REPLAY = f"""
import pytest
import serial
def test_reserial(reserial):
s = serial.Serial(port="/dev/ttyUSB0")
s.write({TEST_TX!r})
assert s.in_waiting == {len(TEST_RX)}
assert s.read() == {TEST_RX!r}
assert s.in_waiting == 0
s.close()
with pytest.raises(serial.PortNotOpenError):
s.read()
def test_reserial2(reserial):
s = serial.Serial(port="/dev/ttyUSB0")
s.write({TEST_TX!r})
assert s.read() == {TEST_RX!r}
s.close()
with pytest.raises(serial.PortNotOpenError):
s.write({TEST_TX!r})
"""
TEST_FILE_BAD_TX = f"""
import serial
def test_reserial(reserial):
Expand All @@ -39,6 +60,10 @@ def patch_write(self: Serial, data: bytes) -> int:
def patch_read(self: Serial, size: int = 1) -> bytes:
return TEST_RX

@property
def patch_in_waiting(self: Serial) -> int:
return len(TEST_RX)

def patch_open(self: Serial) -> None:
self.is_open = True

Expand All @@ -49,6 +74,7 @@ def patch_close(self: Serial) -> None:
monkeypatch.setattr(Serial, "read", patch_read)
monkeypatch.setattr(Serial, "open", patch_open)
monkeypatch.setattr(Serial, "close", patch_close)
monkeypatch.setattr(Serial, "in_waiting", patch_in_waiting)
result = pytester.runpytest("--record")

with open("test_record.jsonl", "r") as f:
Expand Down Expand Up @@ -108,7 +134,7 @@ def patch_close(self: Serial) -> None:

def test_replay(pytester):
pytester.makefile(".jsonl", test_replay=TEST_JSONL)
pytester.makepyfile(TEST_FILE)
pytester.makepyfile(TEST_FILE_REPLAY)
result = pytester.runpytest("--replay")
assert result.ret == 0

Expand Down