From 3cf9293c0d645a696a677c967dc6e3f5303f561c Mon Sep 17 00:00:00 2001 From: Max Marrone Date: Tue, 30 Apr 2024 14:47:24 -0400 Subject: [PATCH] refactor(api): Make sure command implementations return something compatible with the command's result type (#15051) --- .../protocol_engine/commands/command.py | 37 +++++++++++-------- .../execution/test_command_executor.py | 8 +--- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/api/src/opentrons/protocol_engine/commands/command.py b/api/src/opentrons/protocol_engine/commands/command.py index ad43128236d..fcdd7387355 100644 --- a/api/src/opentrons/protocol_engine/commands/command.py +++ b/api/src/opentrons/protocol_engine/commands/command.py @@ -13,6 +13,8 @@ TypeVar, Tuple, List, + Type, + Union, ) from pydantic import BaseModel, Field @@ -29,11 +31,11 @@ from ..state import StateView -CommandParamsT = TypeVar("CommandParamsT", bound=BaseModel) - -CommandResultT = TypeVar("CommandResultT", bound=BaseModel) - -CommandPrivateResultT = TypeVar("CommandPrivateResultT") +_ParamsT = TypeVar("_ParamsT", bound=BaseModel) +_ParamsT_contra = TypeVar("_ParamsT_contra", bound=BaseModel, contravariant=True) +_ResultT = TypeVar("_ResultT", bound=BaseModel) +_ResultT_co = TypeVar("_ResultT_co", bound=BaseModel, covariant=True) +_PrivateResultT_co = TypeVar("_PrivateResultT_co", covariant=True) class CommandStatus(str, Enum): @@ -58,7 +60,7 @@ class CommandIntent(str, Enum): FIXIT = "fixit" -class BaseCommandCreate(GenericModel, Generic[CommandParamsT]): +class BaseCommandCreate(GenericModel, Generic[_ParamsT]): """Base class for command creation requests. You shouldn't use this class directly; instead, use or define @@ -72,7 +74,7 @@ class BaseCommandCreate(GenericModel, Generic[CommandParamsT]): "execution behavior" ), ) - params: CommandParamsT = Field(..., description="Command execution data payload") + params: _ParamsT = Field(..., description="Command execution data payload") intent: Optional[CommandIntent] = Field( None, description=( @@ -97,7 +99,7 @@ class BaseCommandCreate(GenericModel, Generic[CommandParamsT]): ) -class BaseCommand(GenericModel, Generic[CommandParamsT, CommandResultT]): +class BaseCommand(GenericModel, Generic[_ParamsT, _ResultT]): """Base command model. You shouldn't use this class directly; instead, use or define @@ -127,8 +129,8 @@ class BaseCommand(GenericModel, Generic[CommandParamsT, CommandResultT]): ), ) status: CommandStatus = Field(..., description="Command execution status") - params: CommandParamsT = Field(..., description="Command execution data payload") - result: Optional[CommandResultT] = Field( + params: _ParamsT = Field(..., description="Command execution data payload") + result: Optional[_ResultT] = Field( None, description="Command execution result data, if succeeded", ) @@ -167,10 +169,15 @@ class BaseCommand(GenericModel, Generic[CommandParamsT, CommandResultT]): ), ) + _ImplementationCls: Union[ + Type[AbstractCommandImpl[_ParamsT, _ResultT]], + Type[AbstractCommandWithPrivateResultImpl[_ParamsT, _ResultT, object]], + ] + class AbstractCommandImpl( ABC, - Generic[CommandParamsT, CommandResultT], + Generic[_ParamsT_contra, _ResultT_co], ): """Abstract command creation and execution implementation. @@ -204,14 +211,14 @@ def __init__( pass @abstractmethod - async def execute(self, params: CommandParamsT) -> CommandResultT: + async def execute(self, params: _ParamsT_contra) -> _ResultT_co: """Execute the command, mapping data from execution into a response model.""" ... class AbstractCommandWithPrivateResultImpl( ABC, - Generic[CommandParamsT, CommandResultT, CommandPrivateResultT], + Generic[_ParamsT_contra, _ResultT_co, _PrivateResultT_co], ): """Abstract command creation and execution implementation if the command has private results. @@ -247,7 +254,7 @@ def __init__( @abstractmethod async def execute( - self, params: CommandParamsT - ) -> Tuple[CommandResultT, CommandPrivateResultT]: + self, params: _ParamsT_contra + ) -> Tuple[_ResultT_co, _PrivateResultT_co]: """Execute the command, mapping data from execution into a response model.""" ... diff --git a/api/tests/opentrons/protocol_engine/execution/test_command_executor.py b/api/tests/opentrons/protocol_engine/execution/test_command_executor.py index 1cdb051164c..8f4433a9ebe 100644 --- a/api/tests/opentrons/protocol_engine/execution/test_command_executor.py +++ b/api/tests/opentrons/protocol_engine/execution/test_command_executor.py @@ -242,9 +242,7 @@ class _TestCommand(BaseCommand[_TestCommandParams, _TestCommandResult]): params: _TestCommandParams result: Optional[_TestCommandResult] - @property - def _ImplementationCls(self) -> Type[_TestCommandImpl]: - return TestCommandImplCls + _ImplementationCls: Type[_TestCommandImpl] = TestCommandImplCls command_params = _TestCommandParams() command_result = _TestCommandResult() @@ -407,9 +405,7 @@ class _TestCommand(BaseCommand[_TestCommandParams, _TestCommandResult]): params: _TestCommandParams result: Optional[_TestCommandResult] - @property - def _ImplementationCls(self) -> Type[_TestCommandImpl]: - return TestCommandImplCls + _ImplementationCls: Type[_TestCommandImpl] = TestCommandImplCls command_params = _TestCommandParams()