From ecaea254b155d3a3f32557d94ef1f41fdbf22ebb Mon Sep 17 00:00:00 2001 From: Chris Pickett Date: Thu, 15 Dec 2022 13:03:09 -0500 Subject: [PATCH] Create a `CANCELLING` state type (#7794) Co-authored-by: Zach Angell <42625717+zangell44@users.noreply.github.com> --- .../database/migrations/MIGRATION-NOTES.md | 4 + ...ee18b_add_cancelling_to_state_type_enum.py | 22 ++++ ...58a19_rename_worker_pools_to_work_pools.py | 4 +- .../orion/database/query_components.py | 2 +- .../orion/orchestration/core_policy.py | 26 +++- src/prefect/orion/schemas/states.py | 10 ++ tests/orion/models/test_work_queues.py | 24 +++- tests/orion/orchestration/test_core_policy.py | 121 +++++++++++++++++- 8 files changed, 199 insertions(+), 14 deletions(-) create mode 100644 src/prefect/orion/database/migrations/versions/postgresql/2022_12_06_164028_9326a6aee18b_add_cancelling_to_state_type_enum.py diff --git a/src/prefect/orion/database/migrations/MIGRATION-NOTES.md b/src/prefect/orion/database/migrations/MIGRATION-NOTES.md index 44fe39756ac1..5d8ac0fedc5a 100644 --- a/src/prefect/orion/database/migrations/MIGRATION-NOTES.md +++ b/src/prefect/orion/database/migrations/MIGRATION-NOTES.md @@ -8,6 +8,10 @@ Each time a database migration is written, an entry is included here with: This gives us a history of changes and will create merge conflicts if two migrations are made at once, flagging situations where a branch needs to be updated before merging. +# Add `CANCELLING` to StateType enum +SQLite: None +Postgres: `9326a6aee18b` + # Add infrastructure_pid to flow runs SQLite: `7201de756d85` Postgres: `5d526270ddb4` diff --git a/src/prefect/orion/database/migrations/versions/postgresql/2022_12_06_164028_9326a6aee18b_add_cancelling_to_state_type_enum.py b/src/prefect/orion/database/migrations/versions/postgresql/2022_12_06_164028_9326a6aee18b_add_cancelling_to_state_type_enum.py new file mode 100644 index 000000000000..756b0c564894 --- /dev/null +++ b/src/prefect/orion/database/migrations/versions/postgresql/2022_12_06_164028_9326a6aee18b_add_cancelling_to_state_type_enum.py @@ -0,0 +1,22 @@ +"""Add CANCELLING to state type enum + +Revision ID: 9326a6aee18b +Revises: f7587d6c5776 +Create Date: 2022-12-06 16:40:28.282753 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = "9326a6aee18b" +down_revision = "f7587d6c5776" +branch_labels = None +depends_on = None + + +def upgrade(): + op.execute("ALTER TYPE state_type ADD VALUE IF NOT EXISTS 'CANCELLING';") + + +def downgrade(): + pass diff --git a/src/prefect/orion/database/migrations/versions/postgresql/2023_01_08_180142_d481d5058a19_rename_worker_pools_to_work_pools.py b/src/prefect/orion/database/migrations/versions/postgresql/2023_01_08_180142_d481d5058a19_rename_worker_pools_to_work_pools.py index 6db6ca28cf91..ca7845294309 100644 --- a/src/prefect/orion/database/migrations/versions/postgresql/2023_01_08_180142_d481d5058a19_rename_worker_pools_to_work_pools.py +++ b/src/prefect/orion/database/migrations/versions/postgresql/2023_01_08_180142_d481d5058a19_rename_worker_pools_to_work_pools.py @@ -1,7 +1,7 @@ """Rename Worker Pools to Work Pools Revision ID: d481d5058a19 -Revises: f7587d6c5776 +Revises: 9326a6aee18b Create Date: 2023-01-08 18:01:42.559990 """ @@ -12,7 +12,7 @@ # revision identifiers, used by Alembic. revision = "d481d5058a19" -down_revision = "f7587d6c5776" +down_revision = "9326a6aee18b" branch_labels = None depends_on = None diff --git a/src/prefect/orion/database/query_components.py b/src/prefect/orion/database/query_components.py index b569dc080986..0615abce3540 100644 --- a/src/prefect/orion/database/query_components.py +++ b/src/prefect/orion/database/query_components.py @@ -174,7 +174,7 @@ def get_scheduled_flow_runs_from_work_queues( db.FlowRun, sa.and_( self._flow_run_work_queue_join_clause(db.FlowRun, db.WorkQueue), - db.FlowRun.state_type.in_(["RUNNING", "PENDING"]), + db.FlowRun.state_type.in_(["RUNNING", "PENDING", "CANCELLING"]), ), isouter=True, ) diff --git a/src/prefect/orion/orchestration/core_policy.py b/src/prefect/orion/orchestration/core_policy.py index ca086bfe77ad..2fce16f0b32e 100644 --- a/src/prefect/orion/orchestration/core_policy.py +++ b/src/prefect/orion/orchestration/core_policy.py @@ -163,7 +163,8 @@ async def cleanup( class ReleaseTaskConcurrencySlots(BaseUniversalTransform): """ - Releases any concurrency slots held by a run upon exiting a Running state. + Releases any concurrency slots held by a run upon exiting a Running or + Cancelling state. """ async def after_transition( @@ -173,7 +174,10 @@ async def after_transition( if self.nullified_transition(): return - if not context.validated_state.is_running(): + if context.validated_state.type not in [ + states.StateType.RUNNING, + states.StateType.CANCELLING, + ]: filtered_limits = ( await concurrency_limits.filter_concurrency_limits_for_orchestration( context.session, tags=context.run.tags @@ -714,10 +718,23 @@ class PreventRedundantTransitions(BaseOrchestrationRule): StateType.SCHEDULED: 1, StateType.PENDING: 2, StateType.RUNNING: 3, + StateType.CANCELLING: 4, } - FROM_STATES = [StateType.SCHEDULED, StateType.PENDING, StateType.RUNNING, None] - TO_STATES = [StateType.SCHEDULED, StateType.PENDING, StateType.RUNNING, None] + FROM_STATES = [ + StateType.SCHEDULED, + StateType.PENDING, + StateType.RUNNING, + StateType.CANCELLING, + None, + ] + TO_STATES = [ + StateType.SCHEDULED, + StateType.PENDING, + StateType.RUNNING, + StateType.CANCELLING, + None, + ] async def before_transition( self, @@ -727,6 +744,7 @@ async def before_transition( ) -> None: initial_state_type = initial_state.type if initial_state else None proposed_state_type = proposed_state.type if proposed_state else None + if ( self.STATE_PROGRESS[proposed_state_type] <= self.STATE_PROGRESS[initial_state_type] diff --git a/src/prefect/orion/schemas/states.py b/src/prefect/orion/schemas/states.py index 485da941226a..6ffa68e8fc75 100644 --- a/src/prefect/orion/schemas/states.py +++ b/src/prefect/orion/schemas/states.py @@ -27,6 +27,7 @@ class StateType(AutoEnum): CANCELLED = AutoEnum.auto() CRASHED = AutoEnum.auto() PAUSED = AutoEnum.auto() + CANCELLING = AutoEnum.auto() TERMINAL_STATES = { @@ -268,6 +269,15 @@ def Crashed(cls: Type[State] = State, **kwargs) -> State: return cls(type=StateType.CRASHED, **kwargs) +def Cancelling(cls: Type[State] = State, **kwargs) -> State: + """Convenience function for creating `Cancelling` states. + + Returns: + State: a Cancelling state + """ + return cls(type=StateType.CANCELLING, **kwargs) + + def Cancelled(cls: Type[State] = State, **kwargs) -> State: """Convenience function for creating `Cancelled` states. diff --git a/tests/orion/models/test_work_queues.py b/tests/orion/models/test_work_queues.py index 3ec04dc3c307..765a0f53be01 100644 --- a/tests/orion/models/test_work_queues.py +++ b/tests/orion/models/test_work_queues.py @@ -238,6 +238,12 @@ async def test_delete_work_queue_returns_false_if_does_not_exist(self, session): class TestGetRunsInWorkQueue: + running_flow_states = [ + schemas.states.StateType.PENDING, + schemas.states.StateType.CANCELLING, + schemas.states.StateType.RUNNING, + ] + @pytest.fixture async def work_queue_2(self, session): work_queue = await models.work_queues.create_work_queue( @@ -270,7 +276,7 @@ async def scheduled_flow_runs(self, session, deployment, work_queue, work_queue_ @pytest.fixture async def running_flow_runs(self, session, deployment, work_queue, work_queue_2): - for i in range(3): + for state_type in self.running_flow_states: for wq in [work_queue, work_queue_2]: await models.flow_runs.create_flow_run( session=session, @@ -279,7 +285,7 @@ async def running_flow_runs(self, session, deployment, work_queue, work_queue_2) deployment_id=deployment.id, work_queue_name=wq.name, state=schemas.states.State( - type="RUNNING" if i == 0 else "PENDING", + type=state_type, timestamp=pendulum.now("UTC").subtract(seconds=10), ), ), @@ -368,7 +374,9 @@ async def test_get_runs_in_queue_concurrency_limit( session=session, work_queue_id=work_queue.id ) - assert len(runs_wq1) == max(0, min(3, concurrency_limit - 3)) + assert len(runs_wq1) == max( + 0, min(3, concurrency_limit - len(self.running_flow_states)) + ) @pytest.mark.parametrize("limit", [10, 1]) async def test_get_runs_in_queue_concurrency_limit_and_limit( @@ -379,14 +387,20 @@ async def test_get_runs_in_queue_concurrency_limit_and_limit( running_flow_runs, limit, ): + concurrency_limit = 5 + await models.work_queues.update_work_queue( session=session, work_queue_id=work_queue.id, - work_queue=schemas.actions.WorkQueueUpdate(concurrency_limit=5), + work_queue=schemas.actions.WorkQueueUpdate( + concurrency_limit=concurrency_limit + ), ) runs_wq1 = await models.work_queues.get_runs_in_work_queue( session=session, work_queue_id=work_queue.id, limit=limit ) - assert len(runs_wq1) == min(limit, 2) + assert len(runs_wq1) == min( + limit, concurrency_limit - len(self.running_flow_states) + ) diff --git a/tests/orion/orchestration/test_core_policy.py b/tests/orion/orchestration/test_core_policy.py index 4bc9096df4dc..33db69cb1db6 100644 --- a/tests/orion/orchestration/test_core_policy.py +++ b/tests/orion/orchestration/test_core_policy.py @@ -1121,6 +1121,7 @@ async def test_all_other_transitions_are_accepted( @pytest.mark.parametrize("run_type", ["task", "flow"]) class TestPreventingRedundantTransitionsRule: active_states = ( + states.StateType.CANCELLING, states.StateType.RUNNING, states.StateType.PENDING, states.StateType.SCHEDULED, @@ -1300,6 +1301,113 @@ async def test_basic_concurrency_limiting( assert task2_run_retry_ctx.response_status == SetStateStatus.ACCEPT assert (await self.count_concurrency_slots(session, "some tag")) == 1 + async def test_concurrency_limit_cancelling_transition( + self, + session, + run_type, + initialize_orchestration, + ): + await self.create_concurrency_limit(session, "some tag", 1) + concurrency_policy = [SecureTaskConcurrencySlots, ReleaseTaskConcurrencySlots] + running_transition = (states.StateType.PENDING, states.StateType.RUNNING) + cancelling_transition = (states.StateType.RUNNING, states.StateType.CANCELLING) + cancelled_transition = (states.StateType.CANCELLING, states.StateType.CANCELLED) + + # before any runs, no active concurrency slots are in use + assert (await self.count_concurrency_slots(session, "some tag")) == 0 + + task1_running_ctx = await initialize_orchestration( + session, "task", *running_transition, run_tags=["some tag"] + ) + + async with contextlib.AsyncExitStack() as stack: + for rule in concurrency_policy: + task1_running_ctx = await stack.enter_async_context( + rule(task1_running_ctx, *running_transition) + ) + await task1_running_ctx.validate_proposed_state() + + # a first task run against a concurrency limited tag will be accepted + assert task1_running_ctx.response_status == SetStateStatus.ACCEPT + + task2_running_ctx = await initialize_orchestration( + session, "task", *running_transition, run_tags=["some tag"] + ) + + async with contextlib.AsyncExitStack() as stack: + for rule in concurrency_policy: + task2_running_ctx = await stack.enter_async_context( + rule(task2_running_ctx, *running_transition) + ) + await task2_running_ctx.validate_proposed_state() + + # the first task hasn't completed, so the concurrently running second task is + # told to wait + assert task2_running_ctx.response_status == SetStateStatus.WAIT + + # the number of slots occupied by active runs is equal to the concurrency limit + assert (await self.count_concurrency_slots(session, "some tag")) == 1 + + task1_cancelling_ctx = await initialize_orchestration( + session, + "task", + *cancelling_transition, + run_override=task1_running_ctx.run, + run_tags=["some tag"], + ) + + async with contextlib.AsyncExitStack() as stack: + for rule in concurrency_policy: + task1_cancelling_ctx = await stack.enter_async_context( + rule(task1_cancelling_ctx, *cancelling_transition) + ) + await task1_cancelling_ctx.validate_proposed_state() + + # the first task run will transition into a cancelling state, but + # maintain a hold on the concurrency slot + assert task1_cancelling_ctx.response_status == SetStateStatus.ACCEPT + assert (await self.count_concurrency_slots(session, "some tag")) == 1 + + task1_cancelled_ctx = await initialize_orchestration( + session, + "task", + *cancelled_transition, + run_override=task1_running_ctx.run, + run_tags=["some tag"], + ) + + async with contextlib.AsyncExitStack() as stack: + for rule in concurrency_policy: + task1_cancelled_ctx = await stack.enter_async_context( + rule(task1_cancelled_ctx, *cancelled_transition) + ) + await task1_cancelled_ctx.validate_proposed_state() + + # the first task run will transition into a cancelled state, yielding a + # concurrency slot + assert task1_cancelled_ctx.response_status == SetStateStatus.ACCEPT + assert (await self.count_concurrency_slots(session, "some tag")) == 0 + + # the second task tries to run again, this time the transition will be accepted + # now that a concurrency slot has been freed + task2_run_retry_ctx = await initialize_orchestration( + session, + "task", + *running_transition, + run_override=task2_running_ctx.run, + run_tags=["some tag"], + ) + + async with contextlib.AsyncExitStack() as stack: + for rule in concurrency_policy: + task2_run_retry_ctx = await stack.enter_async_context( + rule(task2_run_retry_ctx, *running_transition) + ) + await task2_run_retry_ctx.validate_proposed_state() + + assert task2_run_retry_ctx.response_status == SetStateStatus.ACCEPT + assert (await self.count_concurrency_slots(session, "some tag")) == 1 + async def test_concurrency_limiting_aborts_transitions_with_zero_limit( self, session, @@ -1728,7 +1836,12 @@ async def before_transition(self, initial_state, proposed_state, context): mutated_state.type = random.choice( list( set(states.StateType) - - {initial_state.type, proposed_state.type} + - { + initial_state.type, + proposed_state.type, + states.StateType.RUNNING, + states.StateType.CANCELLING, + } ) ) await self.reject_transition( @@ -1812,7 +1925,11 @@ async def before_transition(self, initial_state, proposed_state, context): mutated_state.type = random.choice( list( set(states.StateType) - - {initial_state.type, proposed_state.type} + - { + initial_state.type, + proposed_state.type, + states.StateType.CANCELLING, + } ) ) await self.reject_transition(