diff --git a/src/prefect/agent.py b/src/prefect/agent.py index e5bf7b4f3803..d585a0300655 100644 --- a/src/prefect/agent.py +++ b/src/prefect/agent.py @@ -219,7 +219,7 @@ async def check_for_cancelled_flow_runs(self): async for work_queue in self.get_work_queues(): work_queue_names.add(work_queue.name) - cancelling_flow_runs = await self.client.read_flow_runs( + named_cancelling_flow_runs = await self.client.read_flow_runs( flow_run_filter=FlowRunFilter( state=FlowRunFilterState( type=FlowRunFilterStateType(any_=[StateType.CANCELLED]), @@ -231,6 +231,19 @@ async def check_for_cancelled_flow_runs(self): ), ) + typed_cancelling_flow_runs = await self.client.read_flow_runs( + flow_run_filter=FlowRunFilter( + state=FlowRunFilterState( + type=FlowRunFilterStateType(any_=[StateType.CANCELLING]), + ), + work_queue_name=FlowRunFilterWorkQueueName(any_=list(work_queue_names)), + # Avoid duplicate cancellation calls + id=FlowRunFilterId(not_any_=list(self.cancelling_flow_run_ids)), + ), + ) + + cancelling_flow_runs = named_cancelling_flow_runs + typed_cancelling_flow_runs + if cancelling_flow_runs: self.logger.info( f"Found {len(cancelling_flow_runs)} flow runs awaiting cancellation." diff --git a/src/prefect/cli/flow_run.py b/src/prefect/cli/flow_run.py index 808201f75a2c..de00d6b48216 100644 --- a/src/prefect/cli/flow_run.py +++ b/src/prefect/cli/flow_run.py @@ -122,7 +122,7 @@ async def delete(id: UUID): async def cancel(id: UUID): """Cancel a flow fun by ID.""" async with get_client() as client: - cancelling_state = State(type=StateType.CANCELLED, name="Cancelling") + cancelling_state = State(type=StateType.CANCELLING) try: result = await client.set_flow_run_state( flow_run_id=id, state=cancelling_state diff --git a/src/prefect/states.py b/src/prefect/states.py index 50d90bf0e707..db6350ea2b45 100644 --- a/src/prefect/states.py +++ b/src/prefect/states.py @@ -498,6 +498,15 @@ def Crashed(cls: Type[State] = State, **kwargs) -> State: return schemas.states.Crashed(cls=cls, **kwargs) +def Cancelling(cls: Type[State] = State, **kwargs) -> State: + """Convenience function for creating `Cancelling` states. + + Returns: + State: a Cancelling state + """ + return schemas.states.Cancelling(cls=cls, **kwargs) + + def Cancelled(cls: Type[State] = State, **kwargs) -> State: """Convenience function for creating `Cancelled` states. diff --git a/tests/agent/test_agent_run_cancellation.py b/tests/agent/test_agent_run_cancellation.py index cfb0896aac11..1e6d2bec1d92 100644 --- a/tests/agent/test_agent_run_cancellation.py +++ b/tests/agent/test_agent_run_cancellation.py @@ -11,11 +11,15 @@ from prefect.infrastructure.base import Infrastructure from prefect.orion.database.orm_models import ORMDeployment from prefect.orion.schemas.core import Deployment -from prefect.states import Cancelled, Completed, Pending, Running, Scheduled +from prefect.states import Cancelled, Cancelling, Completed, Pending, Running, Scheduled from prefect.testing.utilities import AsyncMock from prefect.utilities.dispatch import get_registry_for_type +def legacy_named_cancelling_state(**kwargs): + return Cancelled(name="Cancelling", **kwargs) + + async def _create_test_deployment_from_orm( orion_client: OrionClient, orm_deployment: ORMDeployment, **kwargs ) -> Deployment: @@ -42,12 +46,15 @@ async def _create_test_deployment_from_orm( # Test cancellation is called for the correct flow runs ------------------------------- +@pytest.mark.parametrize( + "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] +) async def test_agent_cancel_run_called_for_cancelling_run( - orion_client: OrionClient, deployment: ORMDeployment + orion_client: OrionClient, deployment: ORMDeployment, cancelling_constructor ): flow_run = await orion_client.create_flow_run_from_deployment( deployment.id, - state=Cancelled(name="Cancelling"), + state=cancelling_constructor(), ) async with OrionAgent( @@ -89,15 +96,20 @@ async def test_agent_cancel_run_not_called_for_other_states( agent.cancel_run.assert_not_called() +@pytest.mark.parametrize( + "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] +) async def test_agent_cancel_run_called_for_cancelling_run_with_multiple_work_queues( - orion_client: OrionClient, deployment: ORMDeployment + orion_client: OrionClient, + deployment: ORMDeployment, + cancelling_constructor, ): deployment.work_queue_name = "foo" await orion_client.update_deployment(deployment) flow_run = await orion_client.create_flow_run_from_deployment( deployment.id, - state=Cancelled(name="Cancelling"), + state=cancelling_constructor(), ) async with OrionAgent(work_queues=["foo", "bar"], prefetch_seconds=10) as agent: @@ -107,8 +119,13 @@ async def test_agent_cancel_run_called_for_cancelling_run_with_multiple_work_que agent.cancel_run.assert_awaited_once_with(flow_run) +@pytest.mark.parametrize( + "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] +) async def test_agent_cancel_run_called_for_each_cancelling_run_in_multiple_work_queues( - orion_client: OrionClient, deployment: ORMDeployment + orion_client: OrionClient, + deployment: ORMDeployment, + cancelling_constructor, ): deployment_foo = await _create_test_deployment_from_orm( orion_client, deployment, work_queue_name="foo" @@ -119,11 +136,11 @@ async def test_agent_cancel_run_called_for_each_cancelling_run_in_multiple_work_ flow_run_foo = await orion_client.create_flow_run_from_deployment( deployment_foo.id, - state=Cancelled(name="Cancelling"), + state=cancelling_constructor(), ) flow_run_bar = await orion_client.create_flow_run_from_deployment( deployment_bar.id, - state=Cancelled(name="Cancelling"), + state=cancelling_constructor(), ) async with OrionAgent(work_queues=["foo", "bar"], prefetch_seconds=10) as agent: @@ -135,8 +152,11 @@ async def test_agent_cancel_run_called_for_each_cancelling_run_in_multiple_work_ ) +@pytest.mark.parametrize( + "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] +) async def test_agent_cancel_run_called_for_each_cancelling_run_in_a_work_queue( - orion_client: OrionClient, deployment: ORMDeployment + orion_client: OrionClient, deployment: ORMDeployment, cancelling_constructor ): deployment_foo = await _create_test_deployment_from_orm( orion_client, deployment, work_queue_name="foo" @@ -144,15 +164,15 @@ async def test_agent_cancel_run_called_for_each_cancelling_run_in_a_work_queue( flow_run_1 = await orion_client.create_flow_run_from_deployment( deployment_foo.id, - state=Cancelled(name="Cancelling"), + state=cancelling_constructor(), ) flow_run_2 = await orion_client.create_flow_run_from_deployment( deployment_foo.id, - state=Cancelled(name="Cancelling"), + state=cancelling_constructor(), ) flow_run_3 = await orion_client.create_flow_run_from_deployment( deployment_foo.id, - state=Cancelled(name="Cancelling"), + state=cancelling_constructor(), ) async with OrionAgent(work_queues=["foo"], prefetch_seconds=10) as agent: @@ -164,12 +184,15 @@ async def test_agent_cancel_run_called_for_each_cancelling_run_in_a_work_queue( ) +@pytest.mark.parametrize( + "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] +) async def test_agent_cancel_run_not_called_for_other_work_queues( - orion_client: OrionClient, deployment + orion_client: OrionClient, deployment, cancelling_constructor ): await orion_client.create_flow_run_from_deployment( deployment.id, - state=Cancelled(name="Cancelling"), + state=cancelling_constructor(), ) async with OrionAgent( @@ -201,14 +224,18 @@ def mock_infrastructure_kill(monkeypatch) -> Generator[AsyncMock, None, None]: yield mock +@pytest.mark.parametrize( + "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] +) async def test_agent_cancel_run_kills_run_with_infrastructure_pid( orion_client: OrionClient, deployment: ORMDeployment, mock_infrastructure_kill: AsyncMock, + cancelling_constructor, ): flow_run = await orion_client.create_flow_run_from_deployment( deployment.id, - state=Cancelled(name="Cancelling"), + state=cancelling_constructor(), ) await orion_client.update_flow_run(flow_run.id, infrastructure_pid="test") @@ -221,15 +248,19 @@ async def test_agent_cancel_run_kills_run_with_infrastructure_pid( mock_infrastructure_kill.assert_awaited_once_with("test") +@pytest.mark.parametrize( + "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] +) async def test_agent_cancel_run_with_missing_infrastructure_pid( orion_client: OrionClient, deployment: ORMDeployment, mock_infrastructure_kill: AsyncMock, caplog, + cancelling_constructor, ): flow_run = await orion_client.create_flow_run_from_deployment( deployment.id, - state=Cancelled(name="Cancelling"), + state=cancelling_constructor(), ) async with OrionAgent( @@ -252,13 +283,17 @@ async def test_agent_cancel_run_with_missing_infrastructure_pid( @pytest.mark.usefixtures("mock_infrastructure_kill") +@pytest.mark.parametrize( + "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] +) async def test_agent_cancel_run_updates_state_name( orion_client: OrionClient, deployment: ORMDeployment, + cancelling_constructor, ): flow_run = await orion_client.create_flow_run_from_deployment( deployment.id, - state=Cancelled(name="Cancelling"), + state=cancelling_constructor(), ) await orion_client.update_flow_run(flow_run.id, infrastructure_pid="test") @@ -273,15 +308,19 @@ async def test_agent_cancel_run_updates_state_name( @pytest.mark.usefixtures("mock_infrastructure_kill") +@pytest.mark.parametrize( + "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] +) async def test_agent_cancel_run_preserves_other_state_properties( orion_client: OrionClient, deployment: ORMDeployment, + cancelling_constructor, ): expected_changed_fields = {"name", "timestamp", "id"} flow_run = await orion_client.create_flow_run_from_deployment( deployment.id, - state=Cancelled(name="Cancelling", message="test"), + state=cancelling_constructor(message="test"), ) await orion_client.update_flow_run(flow_run.id, infrastructure_pid="test") @@ -297,15 +336,19 @@ async def test_agent_cancel_run_preserves_other_state_properties( ) == flow_run.state.dict(exclude=expected_changed_fields) +@pytest.mark.parametrize( + "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] +) async def test_agent_cancel_run_with_infrastructure_not_available_during_kill( orion_client: OrionClient, deployment: ORMDeployment, mock_infrastructure_kill: AsyncMock, caplog, + cancelling_constructor, ): flow_run = await orion_client.create_flow_run_from_deployment( deployment.id, - state=Cancelled(name="Cancelling"), + state=cancelling_constructor(), ) await orion_client.update_flow_run(flow_run.id, infrastructure_pid="test") @@ -333,15 +376,19 @@ async def test_agent_cancel_run_with_infrastructure_not_available_during_kill( assert post_flow_run.state.message is None +@pytest.mark.parametrize( + "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] +) async def test_agent_cancel_run_with_infrastructure_not_found_during_kill( orion_client: OrionClient, deployment: ORMDeployment, mock_infrastructure_kill: AsyncMock, caplog, + cancelling_constructor, ): flow_run = await orion_client.create_flow_run_from_deployment( deployment.id, - state=Cancelled(name="Cancelling"), + state=cancelling_constructor(), ) await orion_client.update_flow_run(flow_run.id, infrastructure_pid="test") @@ -368,15 +415,19 @@ async def test_agent_cancel_run_with_infrastructure_not_found_during_kill( assert post_flow_run.state.message is None +@pytest.mark.parametrize( + "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] +) async def test_agent_cancel_run_with_unknown_error_during_kill( orion_client: OrionClient, deployment: ORMDeployment, mock_infrastructure_kill: AsyncMock, caplog, + cancelling_constructor, ): flow_run = await orion_client.create_flow_run_from_deployment( deployment.id, - state=Cancelled(name="Cancelling"), + state=cancelling_constructor(), ) await orion_client.update_flow_run(flow_run.id, infrastructure_pid="test") mock_infrastructure_kill.side_effect = ValueError("Oh no!") @@ -402,8 +453,15 @@ async def test_agent_cancel_run_with_unknown_error_during_kill( assert "Traceback" in caplog.text +@pytest.mark.parametrize( + "cancelling_constructor", [legacy_named_cancelling_state, Cancelling] +) async def test_agent_cancel_run_without_infrastructure_support_for_kill( - orion_client: OrionClient, deployment: ORMDeployment, caplog, monkeypatch + orion_client: OrionClient, + deployment: ORMDeployment, + caplog, + monkeypatch, + cancelling_constructor, ): # Patch all infrastructure types @@ -415,7 +473,7 @@ async def test_agent_cancel_run_without_infrastructure_support_for_kill( flow_run = await orion_client.create_flow_run_from_deployment( deployment.id, - state=Cancelled(name="Cancelling"), + state=cancelling_constructor(), ) await orion_client.update_flow_run(flow_run.id, infrastructure_pid="test") diff --git a/tests/cli/test_flow_run.py b/tests/cli/test_flow_run.py index 7b8f0311e373..a26100316b77 100644 --- a/tests/cli/test_flow_run.py +++ b/tests/cli/test_flow_run.py @@ -245,8 +245,7 @@ async def test_non_terminal_states_set_to_cancelled(self, orion_client, state): after = await orion_client.read_flow_run(before.id) assert before.state.name != after.state.name assert before.state.type != after.state.type - assert after.state.name == "Cancelling" - assert after.state.type == StateType.CANCELLED + assert after.state.type == StateType.CANCELLING @pytest.mark.parametrize("state", [Completed, Failed, Crashed, Cancelled]) async def test_cancelling_terminal_states_exits_with_error(