diff --git a/api/src/opentrons/execute.py b/api/src/opentrons/execute.py index 63fe43cd2ca..75add87a4d6 100644 --- a/api/src/opentrons/execute.py +++ b/api/src/opentrons/execute.py @@ -5,60 +5,106 @@ regular python shells. It also provides a console entrypoint for running a protocol from the command line. """ +import asyncio import atexit import argparse +import contextlib import logging import os +from pathlib import Path import sys +import tempfile from typing import ( TYPE_CHECKING, BinaryIO, Callable, Dict, + Generator, List, Optional, TextIO, Union, ) +from opentrons_shared_data.labware.labware_definition import LabwareDefinition +from opentrons_shared_data.robot.dev_types import RobotType + from opentrons import protocol_api, __version__, should_use_ot3 -from opentrons.config import IS_ROBOT, JUPYTER_NOTEBOOK_LABWARE_DIR -from opentrons.protocols.execution import execute as execute_apiv2 from opentrons.commands import types as command_types + +from opentrons.config import IS_ROBOT, JUPYTER_NOTEBOOK_LABWARE_DIR + +from opentrons.hardware_control import ( + API as OT2API, + HardwareControlAPI, + ThreadManagedHardware, + ThreadManager, +) + from opentrons.protocols import parse -from opentrons.protocols.types import ApiDeprecationError from opentrons.protocols.api_support.deck_type import ( guess_from_global_config as guess_deck_type_from_global_config, ) from opentrons.protocols.api_support.types import APIVersion -from opentrons.hardware_control import ( - API as OT2API, - ThreadManagedHardware, - ThreadManager, +from opentrons.protocols.execution import execute as execute_apiv2 +from opentrons.protocols.types import ( + ApiDeprecationError, + Protocol, + PythonProtocol, ) -from opentrons_shared_data.robot.dev_types import RobotType -from .util.entrypoint_util import labware_from_paths, datafiles_from_paths +from opentrons.protocol_api.core.engine import ENGINE_CORE_API_VERSION +from opentrons.protocol_api.protocol_context import ProtocolContext + +from opentrons.protocol_engine import ( + Config, + DeckType, + EngineStatus, + ErrorOccurrence as ProtocolEngineErrorOccurrence, + create_protocol_engine, + create_protocol_engine_in_thread, +) + +from opentrons.protocol_reader import ProtocolReader, ProtocolSource + +from opentrons.protocol_runner import create_protocol_runner + +from .util.entrypoint_util import ( + FoundLabware, + labware_from_paths, + datafiles_from_paths, + copy_file_like, +) if TYPE_CHECKING: - from opentrons_shared_data.labware.dev_types import LabwareDefinition + from opentrons_shared_data.labware.dev_types import ( + LabwareDefinition as LabwareDefinitionDict, + ) + _THREAD_MANAGED_HW: Optional[ThreadManagedHardware] = None #: The background global cache that all protocol contexts created by #: :py:meth:`get_protocol_api` will share +# When a ProtocolContext is using a ProtocolEngine to control the robot, it requires some +# additional long-lived resources besides _THREAD_MANAGED_HARDWARE. There's a background thread, +# an asyncio event loop in that thread, and some ProtocolEngine-controlled background tasks in that +# event loop. +# +# When we're executing a protocol file beginning-to-end, we can clean up those resources after it +# completes. However, when someone gets a live ProtocolContext through get_protocol_api(), we have +# no way of knowing when they're done with it. So, as a hack, we keep these resources open +# indefinitely, letting them leak. +# +# We keep this at module scope so that the contained context managers aren't garbage-collected. +# If they're garbage collected, they can close their resources prematurely. +# https://stackoverflow.com/a/69155026/497934 +_LIVE_PROTOCOL_ENGINE_CONTEXTS = contextlib.ExitStack() + + # See Jira RCORE-535. -_PYTHON_TOO_NEW_MESSAGE = ( - "Python protocols with apiLevels higher than 2.13" - " cannot currently be executed with" - " the opentrons_execute command-line tool," - " the opentrons.execute.execute() function," - " or the opentrons.execute.get_protocol_api() function." - " Use a lower apiLevel" - " or use the Opentrons App instead." -) _JSON_TOO_NEW_MESSAGE = ( "Protocols created by recent versions of Protocol Designer" " cannot currently be executed with" @@ -68,11 +114,14 @@ ) +_EmitRunlogCallable = Callable[[command_types.CommandMessage], None] + + def get_protocol_api( version: Union[str, APIVersion], - bundled_labware: Optional[Dict[str, "LabwareDefinition"]] = None, + bundled_labware: Optional[Dict[str, "LabwareDefinitionDict"]] = None, bundled_data: Optional[Dict[str, bytes]] = None, - extra_labware: Optional[Dict[str, "LabwareDefinition"]] = None, + extra_labware: Optional[Dict[str, "LabwareDefinitionDict"]] = None, ) -> protocol_api.ProtocolContext: """ Build and return a ``protocol_api.ProtocolContext`` @@ -117,16 +166,9 @@ def get_protocol_api( else: checked_version = version - if ( - extra_labware is None - and IS_ROBOT - and JUPYTER_NOTEBOOK_LABWARE_DIR.is_dir() # type: ignore[union-attr] - ): + if extra_labware is None: extra_labware = { - uri: details.definition - for uri, details in labware_from_paths( - [str(JUPYTER_NOTEBOOK_LABWARE_DIR)] - ).items() + uri: details.definition for uri, details in _get_jupyter_labware().items() } robot_type = _get_robot_type() @@ -134,8 +176,8 @@ def get_protocol_api( hardware_controller = _get_global_hardware_controller(robot_type) - try: - context = protocol_api.create_protocol_context( + if checked_version < ENGINE_CORE_API_VERSION: + context = _create_live_context_non_pe( api_version=checked_version, deck_type=deck_type, hardware_api=hardware_controller, @@ -143,8 +185,20 @@ def get_protocol_api( bundled_data=bundled_data, extra_labware=extra_labware, ) - except protocol_api.ProtocolEngineCoreRequiredError as e: - raise NotImplementedError(_PYTHON_TOO_NEW_MESSAGE) from e # See Jira RCORE-535. + else: + if bundled_labware is not None: + raise NotImplementedError( + f"The bundled_labware argument is not currently supported for Python protocols" + f" with apiLevel {ENGINE_CORE_API_VERSION} or newer." + ) + context = _create_live_context_pe( + api_version=checked_version, + robot_type=robot_type, + deck_type=guess_deck_type_from_global_config(), + hardware_api=_THREAD_MANAGED_HW, # type: ignore[arg-type] + bundled_data=bundled_data, + extra_labware=extra_labware, + ) hardware_controller.sync.cache_instruments() return context @@ -229,12 +283,12 @@ def get_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: return parser -def execute( +def execute( # noqa: C901 protocol_file: Union[BinaryIO, TextIO], protocol_name: str, propagate_logs: bool = False, log_level: str = "warning", - emit_runlog: Optional[Callable[[command_types.CommandMessage], None]] = None, + emit_runlog: Optional[_EmitRunlogCallable] = None, custom_labware_paths: Optional[List[str]] = None, custom_data_paths: Optional[List[str]] = None, ) -> None: @@ -300,20 +354,18 @@ def execute( # will produce a string with information filled in } } - - """ stack_logger = logging.getLogger("opentrons") stack_logger.propagate = propagate_logs stack_logger.setLevel(getattr(logging, log_level.upper(), logging.WARNING)) + contents = protocol_file.read() + if custom_labware_paths: - extra_labware = { - uri: details.definition - for uri, details in labware_from_paths(custom_labware_paths).items() - } + extra_labware = labware_from_paths(custom_labware_paths) else: - extra_labware = {} + extra_labware = _get_jupyter_labware() + if custom_data_paths: extra_data = datafiles_from_paths(custom_data_paths) else: @@ -321,7 +373,12 @@ def execute( try: protocol = parse.parse( - contents, protocol_name, extra_labware=extra_labware, extra_data=extra_data + contents, + protocol_name, + extra_labware={ + uri: details.definition for uri, details in extra_labware.items() + }, + extra_data=extra_data, ) except parse.JSONSchemaVersionTooNewError as e: if e.attempted_schema_version == 6: @@ -332,24 +389,42 @@ def execute( if protocol.api_level < APIVersion(2, 0): raise ApiDeprecationError(version=protocol.api_level) - else: - bundled_data = getattr(protocol, "bundled_data", {}) - bundled_data.update(extra_data) - gpa_extras = getattr(protocol, "extra_labware", None) or None - context = get_protocol_api( - protocol.api_level, - bundled_labware=getattr(protocol, "bundled_labware", None), - bundled_data=bundled_data, - extra_labware=gpa_extras, + + # Guard against trying to run protocols for the wrong robot type. + # This matches what robot-server does. + if protocol.robot_type != _get_robot_type(): + raise RuntimeError( + f'This robot is of type "{_get_robot_type()}",' + f' so it can\'t execute protocols for robot type "{protocol.robot_type}"' ) + + if protocol.api_level < ENGINE_CORE_API_VERSION: + _run_file_non_pe( + protocol=protocol, + emit_runlog=emit_runlog, + ) + else: + # TODO(mm, 2023-07-06): Once these NotImplementedErrors are resolved, consider removing + # the enclosing if-else block and running everything through _run_file_pe() for simplicity. if emit_runlog: - broker = context.broker - broker.subscribe(command_types.COMMAND, emit_runlog) - context.home() - try: - execute_apiv2.run_protocol(protocol, context) - finally: - context.cleanup() + raise NotImplementedError( + f"Printing the run log is not currently supported for Python protocols" + f" with apiLevel {ENGINE_CORE_API_VERSION} or newer." + f" Pass --no-print-runlog to opentrons_execute" + f" or emit_runlog=None to opentrons.execute.execute()." + ) + if custom_data_paths: + raise NotImplementedError( + f"The custom_data_paths argument is not currently supported for Python protocols" + f" with apiLevel {ENGINE_CORE_API_VERSION} or newer." + ) + protocol_file.seek(0) + _run_file_pe( + protocol_file=protocol_file, + protocol_name=protocol_name, + extra_labware=extra_labware, + hardware_api=_get_global_hardware_controller(_get_robot_type()).wrapped(), + ) def make_runlog_cb() -> Callable[[command_types.CommandMessage], None]: @@ -401,17 +476,198 @@ def main() -> int: stack_logger.addHandler(logging.StreamHandler(sys.stdout)) log_level = args.log_level else: + # TODO(mm, 2023-07-13): This default logging prints error information redundantly + # when executing via Protocol Engine, because Protocol Engine logs when commands fail. log_level = "warning" - # Try to migrate containers from database to v2 format - execute( - protocol_file=args.protocol, - protocol_name=args.protocol.name, - custom_labware_paths=args.custom_labware_path, - custom_data_paths=(args.custom_data_path + args.custom_data_file), - log_level=log_level, - emit_runlog=printer, + + try: + execute( + protocol_file=args.protocol, + protocol_name=args.protocol.name, + custom_labware_paths=args.custom_labware_path, + custom_data_paths=(args.custom_data_path + args.custom_data_file), + log_level=log_level, + emit_runlog=printer, + ) + return 0 + except _ProtocolEngineExecuteError as error: + # _ProtocolEngineExecuteError is a wrapper that's meaningless to the CLI user. + # Take the actual protocol problem out of it and just print that. + print(error.to_stderr_string(), file=sys.stderr) + return 1 + # execute() might raise other exceptions, but we don't have a nice way to print those. + # Just let Python show a traceback. + + +class _ProtocolEngineExecuteError(Exception): + def __init__(self, errors: List[ProtocolEngineErrorOccurrence]) -> None: + """Raised when there was any fatal error running a protocol through Protocol Engine. + + Protocol Engine reports errors as data, not as exceptions. + But the only way for `execute()` to signal problems to its caller is to raise something. + So we need this class to wrap them. + + Params: + errors: The errors that Protocol Engine reported. + """ + # Show the full error details if this is part of a traceback. Don't try to summarize. + super().__init__(errors) + self._error_occurrences = errors + + def to_stderr_string(self) -> str: + """Return a string suitable as the stderr output of the `opentrons_execute` CLI. + + This summarizes from the full error details. + """ + # It's unclear what exactly we should extract here. + # + # First, do we print the first element, or the last, or all of them? + # + # Second, do we print the .detail? .errorCode? .errorInfo? .wrappedErrors? + # By contract, .detail seems like it would be insufficient, but experimentally, + # it includes a lot, like: + # + # ProtocolEngineError [line 3]: Error 4000 GENERAL_ERROR (ProtocolEngineError): + # UnexpectedProtocolError: Labware "fixture_12_trough" not found with version 1 + # in namespace "fixture". + return self._error_occurrences[0].detail + + +def _create_live_context_non_pe( + api_version: APIVersion, + hardware_api: ThreadManagedHardware, + deck_type: str, + extra_labware: Optional[Dict[str, "LabwareDefinitionDict"]], + bundled_labware: Optional[Dict[str, "LabwareDefinitionDict"]], + bundled_data: Optional[Dict[str, bytes]], +) -> ProtocolContext: + """Return a live ProtocolContext. + + This controls the robot through the older infrastructure, instead of through Protocol Engine. + """ + assert api_version < ENGINE_CORE_API_VERSION + return protocol_api.create_protocol_context( + api_version=api_version, + deck_type=deck_type, + hardware_api=hardware_api, + bundled_labware=bundled_labware, + bundled_data=bundled_data, + extra_labware=extra_labware, + ) + + +def _create_live_context_pe( + api_version: APIVersion, + hardware_api: ThreadManagedHardware, + robot_type: RobotType, + deck_type: str, + extra_labware: Dict[str, "LabwareDefinitionDict"], + bundled_data: Optional[Dict[str, bytes]], +) -> ProtocolContext: + """Return a live ProtocolContext that controls the robot through ProtocolEngine.""" + assert api_version >= ENGINE_CORE_API_VERSION + + global _LIVE_PROTOCOL_ENGINE_CONTEXTS + pe, loop = _LIVE_PROTOCOL_ENGINE_CONTEXTS.enter_context( + create_protocol_engine_in_thread( + hardware_api=hardware_api.wrapped(), + config=_get_protocol_engine_config(), + drop_tips_and_home_after=False, + ) ) - return 0 + + # `async def` so we can use loop.run_coroutine_threadsafe() to wait for its completion. + # Non-async would use call_soon_threadsafe(), which makes the waiting harder. + async def add_all_extra_labware() -> None: + for labware_definition_dict in extra_labware.values(): + labware_definition = LabwareDefinition.parse_obj(labware_definition_dict) + pe.add_labware_definition(labware_definition) + + # Add extra_labware to ProtocolEngine, being careful not to modify ProtocolEngine from this + # thread. See concurrency notes in ProtocolEngine docstring. + future = asyncio.run_coroutine_threadsafe(add_all_extra_labware(), loop) + future.result() + + return protocol_api.create_protocol_context( + api_version=api_version, + hardware_api=hardware_api, + deck_type=deck_type, + protocol_engine=pe, + protocol_engine_loop=loop, + bundled_data=bundled_data, + ) + + +def _run_file_non_pe( + protocol: Protocol, + emit_runlog: Optional[_EmitRunlogCallable], +) -> None: + """Run a protocol file without Protocol Engine, with the older infrastructure instead.""" + if isinstance(protocol, PythonProtocol): + extra_labware = protocol.extra_labware + bundled_labware = protocol.bundled_labware + bundled_data = protocol.bundled_data + else: + # JSON protocols do have "bundled labware" embedded in them, but those aren't represented in + # the parsed Protocol object and we don't need to create the ProtocolContext with them. + # execute_apiv2.run_protocol() will pull them out of the JSON and load them into the + # ProtocolContext. + extra_labware = None + bundled_labware = None + bundled_data = None + + context = _create_live_context_non_pe( + api_version=protocol.api_level, + hardware_api=_get_global_hardware_controller(_get_robot_type()), + deck_type=guess_deck_type_from_global_config(), + extra_labware=extra_labware, + bundled_labware=bundled_labware, + bundled_data=bundled_data, + ) + + if emit_runlog: + context.broker.subscribe(command_types.COMMAND, emit_runlog) + + context.home() + try: + execute_apiv2.run_protocol(protocol, context) + finally: + context.cleanup() + + +def _run_file_pe( + protocol_file: Union[BinaryIO, TextIO], + protocol_name: str, + extra_labware: Dict[str, FoundLabware], + hardware_api: HardwareControlAPI, +) -> None: + """Run a protocol file with Protocol Engine.""" + + async def run(protocol_source: ProtocolSource) -> None: + protocol_engine = await create_protocol_engine( + hardware_api=hardware_api, + config=_get_protocol_engine_config(), + ) + + protocol_runner = create_protocol_runner( + protocol_config=protocol_source.config, + protocol_engine=protocol_engine, + hardware_api=hardware_api, + ) + + # TODO(mm, 2023-06-30): This will home and drop tips at the end, which is not how + # things have historically behaved with PAPIv2.13 and older or JSONv5 and older. + result = await protocol_runner.run(protocol_source) + + if result.state_summary.status != EngineStatus.SUCCEEDED: + raise _ProtocolEngineExecuteError(result.state_summary.errors) + + with _adapt_protocol_source( + protocol_file=protocol_file, + protocol_name=protocol_name, + extra_labware=extra_labware, + ) as protocol_source: + asyncio.run(run(protocol_source)) def _get_robot_type() -> RobotType: @@ -419,6 +675,61 @@ def _get_robot_type() -> RobotType: return "OT-3 Standard" if should_use_ot3() else "OT-2 Standard" +def _get_protocol_engine_config() -> Config: + """Return a Protocol Engine config to execute protocols on this device.""" + return Config( + robot_type=_get_robot_type(), + deck_type=DeckType(guess_deck_type_from_global_config()), + # We deliberately omit ignore_pause=True because, in the current implementation of + # opentrons.protocol_api.core.engine, that would incorrectly make + # ProtocolContext.is_simulating() return True. + ) + + +def _get_jupyter_labware() -> Dict[str, FoundLabware]: + """Return labware files in this robot's Jupyter Notebook directory.""" + if IS_ROBOT: + # JUPYTER_NOTEBOOK_LABWARE_DIR should never be None when IS_ROBOT == True. + assert JUPYTER_NOTEBOOK_LABWARE_DIR is not None + if JUPYTER_NOTEBOOK_LABWARE_DIR.is_dir(): + return labware_from_paths([JUPYTER_NOTEBOOK_LABWARE_DIR]) + + return {} + + +@contextlib.contextmanager +def _adapt_protocol_source( + protocol_file: Union[BinaryIO, TextIO], + protocol_name: str, + extra_labware: Dict[str, FoundLabware], +) -> Generator[ProtocolSource, None, None]: + """Create a `ProtocolSource` representing input protocol files.""" + with tempfile.TemporaryDirectory() as temporary_directory: + # It's not well-defined in our customer-facing interfaces whether the supplied protocol_name + # should be just the filename part, or a path with separators. In case it contains stuff + # like "../", sanitize it to just the filename part so we don't save files somewhere bad. + safe_protocol_name = Path(protocol_name).name + + temp_protocol_file = Path(temporary_directory) / safe_protocol_name + + # FIXME(mm, 2023-06-26): Copying this file is pure overhead, and it introduces encoding + # hazards. Remove this when we can parse JSONv6+ and PAPIv2.14+ protocols without going + # through the filesystem. https://opentrons.atlassian.net/browse/RSS-281 + copy_file_like(source=protocol_file, destination=temp_protocol_file) + + custom_labware_files = [labware.path for labware in extra_labware.values()] + + protocol_source = asyncio.run( + ProtocolReader().read_saved( + files=[temp_protocol_file] + custom_labware_files, + directory=None, + files_are_prevalidated=False, + ) + ) + + yield protocol_source + + def _get_global_hardware_controller(robot_type: RobotType) -> ThreadManagedHardware: # Build a hardware controller in a worker thread, which is necessary # because ipython runs its notebook in asyncio but the notebook @@ -446,5 +757,13 @@ def _clear_cached_hardware_controller() -> None: _THREAD_MANAGED_HW = None +# This atexit registration must come after _clear_cached_hardware_controller() +# to ensure we tear things down in order from highest level to lowest level. +@atexit.register +def _clear_live_protocol_engine_contexts() -> None: + global _LIVE_PROTOCOL_ENGINE_CONTEXTS + _LIVE_PROTOCOL_ENGINE_CONTEXTS.close() + + if __name__ == "__main__": sys.exit(main()) diff --git a/api/src/opentrons/protocol_engine/__init__.py b/api/src/opentrons/protocol_engine/__init__.py index 1daaf846f12..a975b497332 100644 --- a/api/src/opentrons/protocol_engine/__init__.py +++ b/api/src/opentrons/protocol_engine/__init__.py @@ -7,7 +7,10 @@ The main interface is the `ProtocolEngine` class. """ -from .create_protocol_engine import create_protocol_engine +from .create_protocol_engine import ( + create_protocol_engine, + create_protocol_engine_in_thread, +) from .protocol_engine import ProtocolEngine from .errors import ProtocolEngineError, ErrorOccurrence from .commands import ( @@ -55,6 +58,7 @@ __all__ = [ # main factory and interface exports "create_protocol_engine", + "create_protocol_engine_in_thread", "ProtocolEngine", "StateSummary", "Config", diff --git a/api/src/opentrons/protocol_engine/create_protocol_engine.py b/api/src/opentrons/protocol_engine/create_protocol_engine.py index cca1669355f..f4e70afc4e7 100644 --- a/api/src/opentrons/protocol_engine/create_protocol_engine.py +++ b/api/src/opentrons/protocol_engine/create_protocol_engine.py @@ -1,13 +1,19 @@ """Main ProtocolEngine factory.""" +import asyncio +import contextlib +import typing + from opentrons.hardware_control import HardwareControlAPI from opentrons.hardware_control.types import DoorState -from opentrons.protocol_engine.resources.module_data_provider import ModuleDataProvider +from opentrons.util.async_helpers import async_context_manager_in_thread from .protocol_engine import ProtocolEngine -from .resources import DeckDataProvider +from .resources import DeckDataProvider, ModuleDataProvider from .state import Config, StateStore +# TODO(mm, 2023-06-16): Arguably, this not being a context manager makes us prone to forgetting to +# clean it up properly, especially in tests. See e.g. https://opentrons.atlassian.net/browse/RSS-222 async def create_protocol_engine( hardware_api: HardwareControlAPI, config: Config, @@ -32,3 +38,54 @@ async def create_protocol_engine( ) return ProtocolEngine(state_store=state_store, hardware_api=hardware_api) + + +@contextlib.contextmanager +def create_protocol_engine_in_thread( + hardware_api: HardwareControlAPI, + config: Config, + drop_tips_and_home_after: bool, +) -> typing.Generator[ + typing.Tuple[ProtocolEngine, asyncio.AbstractEventLoop], None, None +]: + """Run a `ProtocolEngine` in a worker thread. + + When this context manager is entered, it: + + 1. Starts a worker thread. + 2. Starts an asyncio event loop in that worker thread. + 3. Creates and `.play()`s a `ProtocolEngine` in that event loop. + 4. Returns the `ProtocolEngine` and the event loop. + Use functions like `asyncio.run_coroutine_threadsafe()` to safely interact with + the `ProtocolEngine` from your thread. + + When this context manager is exited, it: + + 1. Cleans up the `ProtocolEngine`. + 2. Stops and cleans up the event loop. + 3. Joins the thread. + """ + with async_context_manager_in_thread( + _protocol_engine(hardware_api, config, drop_tips_and_home_after) + ) as ( + protocol_engine, + loop, + ): + yield protocol_engine, loop + + +@contextlib.asynccontextmanager +async def _protocol_engine( + hardware_api: HardwareControlAPI, + config: Config, + drop_tips_and_home_after: bool, +) -> typing.AsyncGenerator[ProtocolEngine, None]: + protocol_engine = await create_protocol_engine( + hardware_api=hardware_api, + config=config, + ) + try: + protocol_engine.play() + yield protocol_engine + finally: + await protocol_engine.finish(drop_tips_and_home=drop_tips_and_home_after) diff --git a/api/src/opentrons/util/async_helpers.py b/api/src/opentrons/util/async_helpers.py index 56606dda468..3e44c11153c 100644 --- a/api/src/opentrons/util/async_helpers.py +++ b/api/src/opentrons/util/async_helpers.py @@ -3,9 +3,21 @@ """ from functools import wraps -from typing import TypeVar, Callable, Awaitable, cast, Any +from threading import Thread +from typing import ( + Any, + AsyncContextManager, + Awaitable, + Callable, + Generator, + Tuple, + TypeVar, + cast, +) import asyncio +import contextlib +import queue async def asyncio_yield() -> None: @@ -36,10 +48,10 @@ async def and await call() that still effectively "block" other concurrent tasks await asyncio.sleep(0) -Wrapped = TypeVar("Wrapped", bound=Callable[..., Awaitable[Any]]) +_Wrapped = TypeVar("_Wrapped", bound=Callable[..., Awaitable[Any]]) -def ensure_yield(async_def_func: Wrapped) -> Wrapped: +def ensure_yield(async_def_func: _Wrapped) -> _Wrapped: """ A decorator that makes sure that asyncio_yield() is called after the decorated async function finishes executing. @@ -57,4 +69,98 @@ async def _wrapper(*args: Any, **kwargs: Any) -> Any: await asyncio_yield() return ret - return cast(Wrapped, _wrapper) + return cast(_Wrapped, _wrapper) + + +_ContextManagerResult = TypeVar("_ContextManagerResult") + + +@contextlib.contextmanager +def async_context_manager_in_thread( + async_context_manager: AsyncContextManager[_ContextManagerResult], +) -> Generator[Tuple[_ContextManagerResult, asyncio.AbstractEventLoop], None, None]: + """Enter an async context manager in a worker thread. + + When you enter this context manager, it: + + 1. Spawns a worker thread. + 2. In that thread, starts an asyncio event loop. + 3. In that event loop, enters the context manager that you passed in. + 4. Returns: the result of entering that context manager, and the running event loop. + Use functions like `asyncio.run_coroutine_threadsafe()` to safely interact + with the returned object from your thread. + + When you exit this context manager, it: + + 1. In the worker thread's event loop, exits the context manager that you passed in. + 2. Stops and cleans up the worker thread's event loop. + 3. Joins the worker thread. + """ + with _run_loop_in_thread() as loop_in_thread: + async_object = asyncio.run_coroutine_threadsafe( + async_context_manager.__aenter__(), + loop=loop_in_thread, + ).result() + + try: + yield async_object, loop_in_thread + + finally: + exit = asyncio.run_coroutine_threadsafe( + async_context_manager.__aexit__(None, None, None), + loop=loop_in_thread, + ) + exit.result() + + +@contextlib.contextmanager +def _run_loop_in_thread() -> Generator[asyncio.AbstractEventLoop, None, None]: + """Run an event loop in a worker thread. + + Entering this context manager spawns a thread, starts an asyncio event loop in it, + and returns that loop. + + Exiting this context manager stops and cleans up the event loop, and then joins the thread. + """ + loop_mailbox: "queue.SimpleQueue[asyncio.AbstractEventLoop]" = queue.SimpleQueue() + + def _in_thread() -> None: + loop = asyncio.new_event_loop() + + # We assume that the lines above this will never fail, + # so we will always reach this point to unblock the parent thread. + loop_mailbox.put(loop) + + loop.run_forever() + + # If we've reached here, the loop has been stopped from outside this thread. Clean it up. + # + # This cleanup is naive because asyncio makes it difficult and confusing to get it right. + # Compare this with asyncio.run()'s cleanup, which: + # + # * Cancels and awaits any remaining tasks + # (according to the source code--this seems undocumented) + # * Shuts down asynchronous generators + # (see asyncio.shutdown_asyncgens()) + # * Shuts down the default thread pool executor + # (see https://bugs.python.org/issue34037 and asyncio.shutdown_default_executor()) + # + # In Python >=3.11, we should rewrite this to use asyncio.Runner, + # which can take care of these nuances for us. + loop.close() + + thread = Thread( + target=_in_thread, + name=f"{__name__} event loop thread", + # This is a load-bearing daemon=True. It avoids @atexit-related deadlocks when this is used + # by opentrons.execute and cleaned up by opentrons.execute's @atexit handler. + # https://github.com/Opentrons/opentrons/pull/12970#issuecomment-1648243785 + daemon=True, + ) + thread.start() + loop_in_thread = loop_mailbox.get() + try: + yield loop_in_thread + finally: + loop_in_thread.call_soon_threadsafe(loop_in_thread.stop) + thread.join() diff --git a/api/src/opentrons/util/entrypoint_util.py b/api/src/opentrons/util/entrypoint_util.py index 5625828f5d4..954d837c2f3 100644 --- a/api/src/opentrons/util/entrypoint_util.py +++ b/api/src/opentrons/util/entrypoint_util.py @@ -5,7 +5,8 @@ import logging from json import JSONDecodeError import pathlib -from typing import Dict, Sequence, Union, TYPE_CHECKING +import shutil +from typing import BinaryIO, Dict, Sequence, TextIO, Union, TYPE_CHECKING from jsonschema import ValidationError # type: ignore @@ -83,3 +84,37 @@ def datafiles_from_paths(paths: Sequence[Union[str, pathlib.Path]]) -> Dict[str, else: log.info(f"ignoring {child} in data path") return datafiles + + +# HACK(mm, 2023-06-29): This function is attempting to do something fundamentally wrong. +# Remove it when we fix https://opentrons.atlassian.net/browse/RSS-281. +def copy_file_like(source: Union[BinaryIO, TextIO], destination: pathlib.Path) -> None: + """Copy a file-like object to a path. + + Limitations: + If `source` is text, the new file's encoding may not correctly match its original encoding. + This can matter if it's a Python file and it has an encoding declaration + (https://docs.python.org/3.7/reference/lexical_analysis.html#encoding-declarations). + Also, its newlines may get translated. + """ + # When we read from the source stream, will it give us bytes, or text? + try: + # Experimentally, this is present (but possibly None) on text-mode streams, + # and not present on binary-mode streams. + getattr(source, "encoding") + except AttributeError: + source_is_text = False + else: + source_is_text = True + + if source_is_text: + destination_mode = "wt" + else: + destination_mode = "wb" + + with open( + destination, + mode=destination_mode, + ) as destination_file: + # Use copyfileobj() to limit memory usage. + shutil.copyfileobj(fsrc=source, fdst=destination_file) diff --git a/api/tests/opentrons/async_context_manager_in_thread.py b/api/tests/opentrons/async_context_manager_in_thread.py deleted file mode 100644 index 75f2d982085..00000000000 --- a/api/tests/opentrons/async_context_manager_in_thread.py +++ /dev/null @@ -1,98 +0,0 @@ -"""A test helper to enter an async context manager in a worker thread.""" - -from __future__ import annotations - -import asyncio -import contextlib -import queue -import typing - -from concurrent.futures import ThreadPoolExecutor - - -_T = typing.TypeVar("_T") - - -@contextlib.contextmanager -def async_context_manager_in_thread( - async_context_manager: typing.AsyncContextManager[_T], -) -> typing.Generator[typing.Tuple[_T, asyncio.AbstractEventLoop], None, None]: - """Enter an async context manager in a worker thread. - - When you enter this context manager, it: - - 1. Spawns a worker thread. - 2. In that thread, starts an asyncio event loop. - 3. In that event loop, enters the context manager that you passed in. - 4. Returns: the result of entering that context manager, and the running event loop. - Use functions like `asyncio.run_coroutine_threadsafe()` to safely interact - with the returned object from your thread. - - When you exit this context manager, it: - - 1. In the worker thread's event loop, exits the context manager that you passed in. - 2. Stops and cleans up the worker thread's event loop. - 3. Joins the worker thread. - """ - with _run_loop_in_thread() as loop_in_thread: - async_object = asyncio.run_coroutine_threadsafe( - async_context_manager.__aenter__(), - loop=loop_in_thread, - ).result() - - try: - yield async_object, loop_in_thread - - finally: - exit = asyncio.run_coroutine_threadsafe( - async_context_manager.__aexit__(None, None, None), - loop=loop_in_thread, - ) - exit.result() - - -@contextlib.contextmanager -def _run_loop_in_thread() -> typing.Generator[asyncio.AbstractEventLoop, None, None]: - """Run an event loop in a worker thread. - - Entering this context manager spawns a thread, starts an asyncio event loop in it, - and returns that loop. - - Exiting this context manager stops and cleans up the event loop, and then joins the thread. - """ - loop_queue: "queue.SimpleQueue[asyncio.AbstractEventLoop]" = queue.SimpleQueue() - - def _in_thread() -> None: - loop = asyncio.new_event_loop() - - # We assume that the lines above this will never fail, - # so we will always reach this point to unblock the parent thread. - loop_queue.put(loop) - - loop.run_forever() - - # If we've reached here, the loop has been stopped from outside this thread. Clean it up. - # - # This cleanup is naive because asyncio makes it difficult and confusing to get it right. - # Compare this with asyncio.run()'s cleanup, which: - # - # * Cancels and awaits any remaining tasks - # (according to the source code--this seems undocumented) - # * Shuts down asynchronous generators - # (see asyncio.shutdown_asyncgens()) - # * Shuts down the default thread pool executor - # (see https://bugs.python.org/issue34037 and asyncio.shutdown_default_executor()) - # - # In Python >=3.11, we should rewrite this to use asyncio.Runner, - # which can take care of these nuances for us. - loop.close() - - with ThreadPoolExecutor(max_workers=1) as executor: - executor.submit(_in_thread) - - loop_in_thread = loop_queue.get() - - try: - yield loop_in_thread - finally: - loop_in_thread.call_soon_threadsafe(loop_in_thread.stop) diff --git a/api/tests/opentrons/conftest.py b/api/tests/opentrons/conftest.py index 2254792ff78..c2b520ad727 100755 --- a/api/tests/opentrons/conftest.py +++ b/api/tests/opentrons/conftest.py @@ -50,13 +50,16 @@ ) from opentrons.protocol_api import ProtocolContext, Labware, create_protocol_context from opentrons.protocol_api.core.legacy.legacy_labware_core import LegacyLabwareCore +from opentrons.protocol_engine import ( + create_protocol_engine_in_thread, + Config as ProtocolEngineConfig, + DeckType, +) from opentrons.protocols.api_support import deck_type from opentrons.protocols.api_support.types import APIVersion from opentrons.protocols.api_support.definitions import MAX_SUPPORTED_VERSION from opentrons.types import Location, Point -from .protocol_engine_in_thread import protocol_engine_in_thread - if TYPE_CHECKING: from opentrons.drivers.smoothie_drivers import SmoothieDriver as SmoothieDriverType @@ -282,7 +285,22 @@ def _make_ot3_pe_ctx( deck_type: str, ) -> Generator[ProtocolContext, None, None]: """Return a ProtocolContext configured for an OT-3 and backed by Protocol Engine.""" - with protocol_engine_in_thread(hardware=hardware) as (engine, loop): + with create_protocol_engine_in_thread( + hardware_api=hardware.wrapped(), + config=ProtocolEngineConfig( + robot_type="OT-3 Standard", + deck_type=DeckType.OT3_STANDARD, + ignore_pause=True, + use_virtual_pipettes=True, + use_virtual_modules=True, + use_virtual_gripper=True, + block_on_door_open=False, + ), + drop_tips_and_home_after=False, + ) as ( + engine, + loop, + ): yield create_protocol_context( api_version=MAX_SUPPORTED_VERSION, hardware_api=hardware, diff --git a/api/tests/opentrons/protocol_engine_in_thread.py b/api/tests/opentrons/protocol_engine_in_thread.py deleted file mode 100644 index ec7ca1b21f6..00000000000 --- a/api/tests/opentrons/protocol_engine_in_thread.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Run a `ProtocolEngine` in a worker thread.""" - -import asyncio -import contextlib -import typing - -from opentrons.hardware_control import ThreadManagedHardware -from opentrons.protocol_engine import ( - create_protocol_engine, - ProtocolEngine, - Config, - DeckType, -) - -from .async_context_manager_in_thread import async_context_manager_in_thread - - -@contextlib.contextmanager -def protocol_engine_in_thread( - hardware: ThreadManagedHardware, -) -> typing.Generator[ - typing.Tuple[ProtocolEngine, asyncio.AbstractEventLoop], None, None -]: - """Run a `ProtocolEngine` in a worker thread. - - When this context manager is entered, it: - - 1. Starts a worker thread. - 2. Starts an asyncio event loop in that worker thread. - 3. Creates and `.play()`s a `ProtocolEngine` in that event loop. - 4. Returns the `ProtocolEngine` and the event loop. - Use functions like `asyncio.run_coroutine_threadsafe()` to safely interact with - the `ProtocolEngine` from your thread. - - When this context manager is exited, it: - - 1. Cleans up the `ProtocolEngine`. - 2. Stops and cleans up the event loop. - 3. Joins the thread. - """ - with async_context_manager_in_thread(_protocol_engine(hardware)) as ( - protocol_engine, - loop, - ): - yield protocol_engine, loop - - -@contextlib.asynccontextmanager -async def _protocol_engine( - hardware: ThreadManagedHardware, -) -> typing.AsyncGenerator[ProtocolEngine, None]: - protocol_engine = await create_protocol_engine( - hardware_api=hardware.wrapped(), - config=Config( - robot_type="OT-3 Standard", - deck_type=DeckType.OT3_STANDARD, - ignore_pause=True, - use_virtual_pipettes=True, - use_virtual_modules=True, - use_virtual_gripper=True, - block_on_door_open=False, - ), - ) - try: - protocol_engine.play() - yield protocol_engine - finally: - await protocol_engine.finish() diff --git a/api/tests/opentrons/test_async_context_manager_in_thread.py b/api/tests/opentrons/test_async_context_manager_in_thread.py deleted file mode 100644 index 9eaf63c438c..00000000000 --- a/api/tests/opentrons/test_async_context_manager_in_thread.py +++ /dev/null @@ -1,114 +0,0 @@ -"""Tests for the `async_context_manager_in_thread` helper.""" - - -import asyncio - -import pytest - -from .async_context_manager_in_thread import async_context_manager_in_thread - - -def test_enters_and_exits() -> None: - """It should enter and exit the given context manager appropriately, and return its result.""" - - class ContextManager: - def __init__(self) -> None: - self.entered = False - self.exited = False - - async def __aenter__(self) -> str: - self.entered = True - return "Yay!" - - async def __aexit__( - self, exc_type: object, exc_val: object, exc_tb: object - ) -> None: - self.exited = True - - context_manager = ContextManager() - - assert not context_manager.entered - assert not context_manager.exited - - with async_context_manager_in_thread(context_manager) as (result, _): - assert context_manager.entered - assert not context_manager.exited - assert result == "Yay!" - - assert context_manager.exited - - -def test_returns_matching_loop() -> None: - """It should return the event loop that the given context manager is running in.""" - - class ContextManager: - async def __aenter__(self) -> asyncio.AbstractEventLoop: - return asyncio.get_running_loop() - - async def __aexit__( - self, exc_type: object, exc_val: object, exc_tb: object - ) -> None: - pass - - context_manager = ContextManager() - with async_context_manager_in_thread(context_manager) as (result, loop_in_thread): - assert result is loop_in_thread - - -def test_loop_lifetime() -> None: - """Test the lifetime of the returned event loop. - - While the context manager is open, the event loop should be running and usable. - After the context manager closes, the event loop should be closed and unusable. - """ - - class NoOp: - async def __aenter__(self) -> None: - return None - - async def __aexit__( - self, exc_type: object, exc_val: object, exc_tb: object - ) -> None: - pass - - with async_context_manager_in_thread(NoOp()) as (_, loop_in_thread): - asyncio.run_coroutine_threadsafe(asyncio.sleep(0.000001), loop_in_thread) - - with pytest.raises(RuntimeError, match="Event loop is closed"): - loop_in_thread.call_soon_threadsafe(lambda: None) - - -def test_propagates_exception_from_enter() -> None: - """If the given context manager raises an exception when it's entered, it should propagate.""" - - class RaiseExceptionOnEnter: - async def __aenter__(self) -> None: - raise RuntimeError("Oh the humanity.") - - async def __aexit__( - self, exc_type: object, exc_val: object, exc_tb: object - ) -> None: - assert False, "We should not reach here." - - context_manager = RaiseExceptionOnEnter() - with pytest.raises(RuntimeError, match="Oh the humanity"): - with async_context_manager_in_thread(context_manager): - assert False, "We should not reach here." - - -def test_propagates_exception_from_exit() -> None: - """If the given context manager raises an exception when it's exited, it should propagate.""" - - class RaiseExceptionOnExit: - async def __aenter__(self) -> None: - return None - - async def __aexit__( - self, exc_type: object, exc_val: object, exc_tb: object - ) -> None: - raise RuntimeError("Oh the humanity.") - - context_manager = RaiseExceptionOnExit() - with pytest.raises(RuntimeError, match="Oh the humanity"): - with async_context_manager_in_thread(context_manager): - assert False, "We should not reach here." diff --git a/api/tests/opentrons/test_execute.py b/api/tests/opentrons/test_execute.py index e5d829b2ba6..e986fc1ed7c 100644 --- a/api/tests/opentrons/test_execute.py +++ b/api/tests/opentrons/test_execute.py @@ -16,9 +16,11 @@ pipette_load_name_conversions as pipette_load_name, load_data as load_pipette_data, ) + from opentrons import execute, types -from opentrons.protocols.api_support.types import APIVersion from opentrons.hardware_control import Controller, api +from opentrons.protocol_api.core.engine import ENGINE_CORE_API_VERSION +from opentrons.protocols.api_support.types import APIVersion if TYPE_CHECKING: from tests.opentrons.conftest import Bundle, Protocol @@ -27,13 +29,7 @@ HERE = Path(__file__).parent -@pytest.fixture( - params=[ - APIVersion(2, 0), - # TODO(mm, 2023-07-14): Enable this for https://opentrons.atlassian.net/browse/RSS-268. - # ENGINE_CORE_API_VERSION, - ] -) +@pytest.fixture(params=[APIVersion(2, 0), ENGINE_CORE_API_VERSION]) def api_version(request: pytest.FixtureRequest) -> APIVersion: """Return an API version to test with. @@ -63,12 +59,15 @@ async def dummy_delay(self: Any, duration_s: float) -> None: @pytest.mark.parametrize( - "protocol_file", + ("protocol_file", "expect_run_log"), [ - "testosaur_v2.py", - # TODO(mm, 2023-07-14): Resolve this xfail. https://opentrons.atlassian.net/browse/RSS-268 + ("testosaur_v2.py", True), + ("testosaur_v2_14.py", False), + # FIXME(mm, 2023-07-20): Support printing the run log when executing new protocols. + # Then, remove this expect_run_log parametrization (it should always be True). pytest.param( "testosaur_v2_14.py", + True, marks=pytest.mark.xfail(strict=True, raises=NotImplementedError), ), ], @@ -76,7 +75,7 @@ async def dummy_delay(self: Any, duration_s: float) -> None: def test_execute_function_apiv2( protocol: Protocol, protocol_file: str, - monkeypatch: pytest.MonkeyPatch, + expect_run_log: bool, virtual_smoothie_env: None, mock_get_attached_instr: mock.AsyncMock, ) -> None: @@ -110,13 +109,21 @@ def emit_runlog(entry: Any) -> None: nonlocal entries entries.append(entry) - execute.execute(protocol.filelike, protocol.filename, emit_runlog=emit_runlog) - assert [item["payload"]["text"] for item in entries if item["$"] == "before"] == [ - "Picking up tip from A1 of Opentrons 96 Tip Rack 1000 µL on 1", - "Aspirating 100.0 uL from A1 of Corning 96 Well Plate 360 µL Flat on 2 at 500.0 uL/sec", - "Dispensing 100.0 uL into B1 of Corning 96 Well Plate 360 µL Flat on 2 at 1000.0 uL/sec", - "Dropping tip into H12 of Opentrons 96 Tip Rack 1000 µL on 1", - ] + execute.execute( + protocol.filelike, + protocol.filename, + emit_runlog=(emit_runlog if expect_run_log else None), + ) + + if expect_run_log: + assert [ + item["payload"]["text"] for item in entries if item["$"] == "before" + ] == [ + "Picking up tip from A1 of Opentrons 96 Tip Rack 1000 µL on 1", + "Aspirating 100.0 uL from A1 of Corning 96 Well Plate 360 µL Flat on 2 at 500.0 uL/sec", + "Dispensing 100.0 uL into B1 of Corning 96 Well Plate 360 µL Flat on 2 at 1000.0 uL/sec", + "Dropping tip into H12 of Opentrons 96 Tip Rack 1000 µL on 1", + ] def test_execute_function_json_v3( diff --git a/api/tests/opentrons/util/test_async_helpers.py b/api/tests/opentrons/util/test_async_helpers.py new file mode 100644 index 00000000000..14f9e1a0436 --- /dev/null +++ b/api/tests/opentrons/util/test_async_helpers.py @@ -0,0 +1,126 @@ +import asyncio + +import pytest + +from opentrons.util import async_helpers as subject + + +class TestAsyncContextManagerInThread: + """Tests for `async_context_manager_in_thread()`.""" + + @staticmethod + def test_enters_and_exits() -> None: + """It should enter and exit the given context manager appropriately, and return its result.""" + + class ContextManager: + def __init__(self) -> None: + self.entered = False + self.exited = False + + async def __aenter__(self) -> str: + self.entered = True + return "Yay!" + + async def __aexit__( + self, exc_type: object, exc_val: object, exc_tb: object + ) -> None: + self.exited = True + + context_manager = ContextManager() + + assert not context_manager.entered + assert not context_manager.exited + + with subject.async_context_manager_in_thread(context_manager) as (result, _): + assert context_manager.entered + assert not context_manager.exited + assert result == "Yay!" + + assert context_manager.exited + + @staticmethod + def test_returns_matching_loop() -> None: + """It should return the event loop that the given context manager is running in.""" + + class ContextManager: + async def __aenter__(self) -> asyncio.AbstractEventLoop: + return asyncio.get_running_loop() + + async def __aexit__( + self, exc_type: object, exc_val: object, exc_tb: object + ) -> None: + pass + + context_manager = ContextManager() + with subject.async_context_manager_in_thread(context_manager) as ( + result, + loop_in_thread, + ): + assert result is loop_in_thread + + @staticmethod + def test_loop_lifetime() -> None: + """Test the lifetime of the returned event loop. + + While the context manager is open, the event loop should be running and usable. + After the context manager closes, the event loop should be closed and unusable. + """ + + class NoOp: + async def __aenter__(self) -> None: + return None + + async def __aexit__( + self, exc_type: object, exc_val: object, exc_tb: object + ) -> None: + pass + + with subject.async_context_manager_in_thread(NoOp()) as (_, loop_in_thread): + # As a smoke test to see if the event loop is running and usable, + # run an arbitrary coroutine and wait for it to finish. + ( + asyncio.run_coroutine_threadsafe( + asyncio.sleep(0.000001), loop_in_thread + ) + ).result() + + # The loop should be closed and unusable now that the context manager has exited. + assert loop_in_thread.is_closed + with pytest.raises(RuntimeError, match="Event loop is closed"): + loop_in_thread.call_soon_threadsafe(lambda: None) + + @staticmethod + def test_propagates_exception_from_enter() -> None: + """If the given context manager raises an exception when it's entered, it should propagate.""" + + class RaiseExceptionOnEnter: + async def __aenter__(self) -> None: + raise RuntimeError("Oh the humanity.") + + async def __aexit__( + self, exc_type: object, exc_val: object, exc_tb: object + ) -> None: + assert False, "We should not reach here." + + context_manager = RaiseExceptionOnEnter() + with pytest.raises(RuntimeError, match="Oh the humanity"): + with subject.async_context_manager_in_thread(context_manager): + assert False, "We should not reach here." + + @staticmethod + def test_propagates_exception_from_exit() -> None: + """If the given context manager raises an exception when it's exited, it should propagate.""" + + class RaiseExceptionOnExit: + async def __aenter__(self) -> None: + return None + + async def __aexit__( + self, exc_type: object, exc_val: object, exc_tb: object + ) -> None: + raise RuntimeError("Oh the humanity.") + + context_manager = RaiseExceptionOnExit() + with pytest.raises(RuntimeError, match="Oh the humanity"): + with subject.async_context_manager_in_thread(context_manager): + pass diff --git a/api/tests/opentrons/util/test_entrypoint_utils.py b/api/tests/opentrons/util/test_entrypoint_util.py similarity index 54% rename from api/tests/opentrons/util/test_entrypoint_utils.py rename to api/tests/opentrons/util/test_entrypoint_util.py index c30351dec3b..4e82ce0e80f 100644 --- a/api/tests/opentrons/util/test_entrypoint_utils.py +++ b/api/tests/opentrons/util/test_entrypoint_util.py @@ -1,13 +1,17 @@ +import io import json import os from pathlib import Path from typing import Callable +import pytest + from opentrons_shared_data.labware.dev_types import LabwareDefinition as LabwareDefDict from opentrons.util.entrypoint_util import ( FoundLabware, labware_from_paths, datafiles_from_paths, + copy_file_like, ) @@ -75,3 +79,70 @@ def test_datafiles_from_paths(tmp_path: Path) -> None: "test1": "wait theres a second file???".encode(), "test-file": "this isnt even in a directory".encode(), } + + +class TestCopyFileLike: + """Tests for `copy_file_like()`.""" + + @pytest.fixture(params=["abc", "µ"]) + def source_text(self, request: pytest.FixtureRequest) -> str: + return request.param # type: ignore[attr-defined,no-any-return] + + @pytest.fixture + def source_bytes(self, source_text: str) -> bytes: + return b"\x00\x01\x02\x03\x04" + + @pytest.fixture + def source_path(self, tmp_path: Path) -> Path: + return tmp_path / "source" + + @pytest.fixture + def destination_path(self, tmp_path: Path) -> Path: + return tmp_path / "destination" + + def test_from_text_file( + self, + source_text: str, + source_path: Path, + destination_path: Path, + ) -> None: + """Test that it correctly copies from a text-mode `open()`.""" + source_path.write_text(source_text) + + with open( + source_path, + mode="rt", + ) as source_file: + copy_file_like(source=source_file, destination=destination_path) + + assert destination_path.read_text() == source_text + + def test_from_binary_file( + self, + source_bytes: bytes, + source_path: Path, + destination_path: Path, + ) -> None: + """Test that it correctly copies from a binary-mode `open()`.""" + source_path.write_bytes(source_bytes) + + with open(source_path, mode="rb") as source_file: + copy_file_like(source=source_file, destination=destination_path) + + assert destination_path.read_bytes() == source_bytes + + def test_from_stringio(self, source_text: str, destination_path: Path) -> None: + """Test that it correctly copies from an `io.StringIO`.""" + stringio = io.StringIO(source_text) + + copy_file_like(source=stringio, destination=destination_path) + + assert destination_path.read_text() == source_text + + def test_from_bytesio(self, source_bytes: bytes, destination_path: Path) -> None: + """Test that it correctly copies from an `io.BytesIO`.""" + bytesio = io.BytesIO(source_bytes) + + copy_file_like(source=bytesio, destination=destination_path) + + assert destination_path.read_bytes() == source_bytes