From ff1ceabb1c21c28ef41ace445f51259e6fee9dd5 Mon Sep 17 00:00:00 2001 From: Chris Pickett Date: Tue, 6 Dec 2022 11:59:03 -0500 Subject: [PATCH 1/9] Create `CANCELLING` state --- src/prefect/orion/schemas/states.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/prefect/orion/schemas/states.py b/src/prefect/orion/schemas/states.py index d5378563db37..35d53cc2a4e5 100644 --- a/src/prefect/orion/schemas/states.py +++ b/src/prefect/orion/schemas/states.py @@ -27,6 +27,7 @@ class StateType(AutoEnum): CANCELLED = AutoEnum.auto() CRASHED = AutoEnum.auto() PAUSED = AutoEnum.auto() + CANCELLING = AutoEnum.auto() TERMINAL_STATES = { @@ -267,6 +268,15 @@ def Crashed(cls: Type[State] = State, **kwargs) -> State: return cls(type=StateType.CRASHED, **kwargs) +def Cancelling(cls: Type[State] = State, **kwargs) -> State: + """Convenience function for creating `Cancelling` states. + + Returns: + State: a Cancelling state + """ + return cls(type=StateType.CANCELLING, **kwargs) + + def Cancelled(cls: Type[State] = State, **kwargs) -> State: """Convenience function for creating `Cancelled` states. From 105f60b7686b07787cd6ed397380ceb9fd3a0784 Mon Sep 17 00:00:00 2001 From: Chris Pickett Date: Tue, 6 Dec 2022 13:02:57 -0500 Subject: [PATCH 2/9] Count `CANCELLING` flow runs against work queue concurrency limits --- .../orion/database/query_components.py | 2 +- tests/orion/models/test_work_queues.py | 25 +++++++++++++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/prefect/orion/database/query_components.py b/src/prefect/orion/database/query_components.py index 46cde060f307..8adfbaa6bb9e 100644 --- a/src/prefect/orion/database/query_components.py +++ b/src/prefect/orion/database/query_components.py @@ -166,7 +166,7 @@ def get_scheduled_flow_runs_from_work_queues( db.FlowRun, sa.and_( self._flow_run_work_queue_join_clause(db.FlowRun, db.WorkQueue), - db.FlowRun.state_type.in_(["RUNNING", "PENDING"]), + db.FlowRun.state_type.in_(["RUNNING", "PENDING", "CANCELLING"]), ), isouter=True, ) diff --git a/tests/orion/models/test_work_queues.py b/tests/orion/models/test_work_queues.py index 3ec04dc3c307..875e2fcb6b8d 100644 --- a/tests/orion/models/test_work_queues.py +++ b/tests/orion/models/test_work_queues.py @@ -238,6 +238,8 @@ async def test_delete_work_queue_returns_false_if_does_not_exist(self, session): class TestGetRunsInWorkQueue: + running_flow_count = 4 + @pytest.fixture async def work_queue_2(self, session): work_queue = await models.work_queues.create_work_queue( @@ -270,8 +272,15 @@ async def scheduled_flow_runs(self, session, deployment, work_queue, work_queue_ @pytest.fixture async def running_flow_runs(self, session, deployment, work_queue, work_queue_2): - for i in range(3): + for i in range(self.running_flow_count): for wq in [work_queue, work_queue_2]: + if i == 0: + state_type = "PENDING" + elif i == 1: + state_type = "CANCELLING" + else: + state_type = "RUNNING" + await models.flow_runs.create_flow_run( session=session, flow_run=schemas.core.FlowRun( @@ -279,7 +288,7 @@ async def running_flow_runs(self, session, deployment, work_queue, work_queue_2) deployment_id=deployment.id, work_queue_name=wq.name, state=schemas.states.State( - type="RUNNING" if i == 0 else "PENDING", + type=state_type, timestamp=pendulum.now("UTC").subtract(seconds=10), ), ), @@ -368,7 +377,9 @@ async def test_get_runs_in_queue_concurrency_limit( session=session, work_queue_id=work_queue.id ) - assert len(runs_wq1) == max(0, min(3, concurrency_limit - 3)) + assert len(runs_wq1) == max( + 0, min(3, concurrency_limit - self.running_flow_count) + ) @pytest.mark.parametrize("limit", [10, 1]) async def test_get_runs_in_queue_concurrency_limit_and_limit( @@ -379,14 +390,18 @@ async def test_get_runs_in_queue_concurrency_limit_and_limit( running_flow_runs, limit, ): + concurrency_limit = 5 + await models.work_queues.update_work_queue( session=session, work_queue_id=work_queue.id, - work_queue=schemas.actions.WorkQueueUpdate(concurrency_limit=5), + work_queue=schemas.actions.WorkQueueUpdate( + concurrency_limit=concurrency_limit + ), ) runs_wq1 = await models.work_queues.get_runs_in_work_queue( session=session, work_queue_id=work_queue.id, limit=limit ) - assert len(runs_wq1) == min(limit, 2) + assert len(runs_wq1) == min(limit, concurrency_limit - self.running_flow_count) From ec2ba99becd6b22623fd5697f54f9007fe7f2488 Mon Sep 17 00:00:00 2001 From: Chris Pickett Date: Tue, 6 Dec 2022 15:54:33 -0500 Subject: [PATCH 3/9] Maintain concurrency slots during transition from RUNNING -> CANCELLING --- .../orion/orchestration/core_policy.py | 8 +- tests/orion/orchestration/test_core_policy.py | 107 ++++++++++++++++++ 2 files changed, 113 insertions(+), 2 deletions(-) diff --git a/src/prefect/orion/orchestration/core_policy.py b/src/prefect/orion/orchestration/core_policy.py index 170aff138046..eb110d328cc1 100644 --- a/src/prefect/orion/orchestration/core_policy.py +++ b/src/prefect/orion/orchestration/core_policy.py @@ -162,7 +162,8 @@ async def cleanup( class ReleaseTaskConcurrencySlots(BaseUniversalTransform): """ - Releases any concurrency slots held by a run upon exiting a Running state. + Releases any concurrency slots held by a run upon exiting a Running or + Cancelling state. """ async def after_transition( @@ -172,7 +173,10 @@ async def after_transition( if self.nullified_transition(): return - if not context.validated_state.is_running(): + if context.validated_state.type not in [ + states.StateType.RUNNING, + states.StateType.CANCELLING, + ]: filtered_limits = ( await concurrency_limits.filter_concurrency_limits_for_orchestration( context.session, tags=context.run.tags diff --git a/tests/orion/orchestration/test_core_policy.py b/tests/orion/orchestration/test_core_policy.py index d83ce39029c9..e8d0277a3172 100644 --- a/tests/orion/orchestration/test_core_policy.py +++ b/tests/orion/orchestration/test_core_policy.py @@ -1147,6 +1147,113 @@ async def test_basic_concurrency_limiting( assert task2_run_retry_ctx.response_status == SetStateStatus.ACCEPT assert (await self.count_concurrency_slots(session, "some tag")) == 1 + async def test_concurrency_limit_cancelling_transition( + self, + session, + run_type, + initialize_orchestration, + ): + await self.create_concurrency_limit(session, "some tag", 1) + concurrency_policy = [SecureTaskConcurrencySlots, ReleaseTaskConcurrencySlots] + running_transition = (states.StateType.PENDING, states.StateType.RUNNING) + cancelling_transition = (states.StateType.RUNNING, states.StateType.CANCELLING) + cancelled_transition = (states.StateType.CANCELLING, states.StateType.CANCELLED) + + # before any runs, no active concurrency slots are in use + assert (await self.count_concurrency_slots(session, "some tag")) == 0 + + task1_running_ctx = await initialize_orchestration( + session, "task", *running_transition, run_tags=["some tag"] + ) + + async with contextlib.AsyncExitStack() as stack: + for rule in concurrency_policy: + task1_running_ctx = await stack.enter_async_context( + rule(task1_running_ctx, *running_transition) + ) + await task1_running_ctx.validate_proposed_state() + + # a first task run against a concurrency limited tag will be accepted + assert task1_running_ctx.response_status == SetStateStatus.ACCEPT + + task2_running_ctx = await initialize_orchestration( + session, "task", *running_transition, run_tags=["some tag"] + ) + + async with contextlib.AsyncExitStack() as stack: + for rule in concurrency_policy: + task2_running_ctx = await stack.enter_async_context( + rule(task2_running_ctx, *running_transition) + ) + await task2_running_ctx.validate_proposed_state() + + # the first task hasn't completed, so the concurrently running second task is + # told to wait + assert task2_running_ctx.response_status == SetStateStatus.WAIT + + # the number of slots occupied by active runs is equal to the concurrency limit + assert (await self.count_concurrency_slots(session, "some tag")) == 1 + + task1_cancelling_ctx = await initialize_orchestration( + session, + "task", + *cancelling_transition, + run_override=task1_running_ctx.run, + run_tags=["some tag"], + ) + + async with contextlib.AsyncExitStack() as stack: + for rule in concurrency_policy: + task1_cancelling_ctx = await stack.enter_async_context( + rule(task1_cancelling_ctx, *cancelling_transition) + ) + await task1_cancelling_ctx.validate_proposed_state() + + # the first task run will transition into a cancelling state, but + # maintain a hold on the concurrency slot + assert task1_cancelling_ctx.response_status == SetStateStatus.ACCEPT + assert (await self.count_concurrency_slots(session, "some tag")) == 1 + + task1_cancelled_ctx = await initialize_orchestration( + session, + "task", + *cancelled_transition, + run_override=task1_running_ctx.run, + run_tags=["some tag"], + ) + + async with contextlib.AsyncExitStack() as stack: + for rule in concurrency_policy: + task1_cancelled_ctx = await stack.enter_async_context( + rule(task1_cancelled_ctx, *cancelled_transition) + ) + await task1_cancelled_ctx.validate_proposed_state() + + # the first task run will transition into a cancelled state, yielding a + # concurrency slot + assert task1_cancelled_ctx.response_status == SetStateStatus.ACCEPT + assert (await self.count_concurrency_slots(session, "some tag")) == 0 + + # the second task tries to run again, this time the transition will be accepted + # now that a concurrency slot has been freed + task2_run_retry_ctx = await initialize_orchestration( + session, + "task", + *running_transition, + run_override=task2_running_ctx.run, + run_tags=["some tag"], + ) + + async with contextlib.AsyncExitStack() as stack: + for rule in concurrency_policy: + task2_run_retry_ctx = await stack.enter_async_context( + rule(task2_run_retry_ctx, *running_transition) + ) + await task2_run_retry_ctx.validate_proposed_state() + + assert task2_run_retry_ctx.response_status == SetStateStatus.ACCEPT + assert (await self.count_concurrency_slots(session, "some tag")) == 1 + async def test_concurrency_limiting_aborts_transitions_with_zero_limit( self, session, From bf23bddfd13524a9f0e88798a620a3518bf2ac6b Mon Sep 17 00:00:00 2001 From: Chris Pickett Date: Tue, 6 Dec 2022 16:21:30 -0500 Subject: [PATCH 4/9] Create migration to add `CANCELLING` to StateType enum --- .../database/migrations/MIGRATION-NOTES.md | 4 ++++ ...ee18b_add_cancelling_to_state_type_enum.py | 22 +++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 src/prefect/orion/database/migrations/versions/postgresql/2022_12_06_164028_9326a6aee18b_add_cancelling_to_state_type_enum.py diff --git a/src/prefect/orion/database/migrations/MIGRATION-NOTES.md b/src/prefect/orion/database/migrations/MIGRATION-NOTES.md index 44fe39756ac1..5d8ac0fedc5a 100644 --- a/src/prefect/orion/database/migrations/MIGRATION-NOTES.md +++ b/src/prefect/orion/database/migrations/MIGRATION-NOTES.md @@ -8,6 +8,10 @@ Each time a database migration is written, an entry is included here with: This gives us a history of changes and will create merge conflicts if two migrations are made at once, flagging situations where a branch needs to be updated before merging. +# Add `CANCELLING` to StateType enum +SQLite: None +Postgres: `9326a6aee18b` + # Add infrastructure_pid to flow runs SQLite: `7201de756d85` Postgres: `5d526270ddb4` diff --git a/src/prefect/orion/database/migrations/versions/postgresql/2022_12_06_164028_9326a6aee18b_add_cancelling_to_state_type_enum.py b/src/prefect/orion/database/migrations/versions/postgresql/2022_12_06_164028_9326a6aee18b_add_cancelling_to_state_type_enum.py new file mode 100644 index 000000000000..b485b321de35 --- /dev/null +++ b/src/prefect/orion/database/migrations/versions/postgresql/2022_12_06_164028_9326a6aee18b_add_cancelling_to_state_type_enum.py @@ -0,0 +1,22 @@ +"""Add CANCELLING to state type enum + +Revision ID: 9326a6aee18b +Revises: 5e4f924ff96c +Create Date: 2022-12-06 16:40:28.282753 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = "9326a6aee18b" +down_revision = "5e4f924ff96c" +branch_labels = None +depends_on = None + + +def upgrade(): + op.execute("ALTER TYPE state_type ADD VALUE IF NOT EXISTS 'CANCELLING';") + + +def downgrade(): + pass From 1e838f69af5e2ee6e5b1f7b2dd8248fc3897ff2c Mon Sep 17 00:00:00 2001 From: Chris Pickett Date: Wed, 7 Dec 2022 11:56:11 -0500 Subject: [PATCH 5/9] Make work queues test a bit more readable --- tests/orion/models/test_work_queues.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/orion/models/test_work_queues.py b/tests/orion/models/test_work_queues.py index 875e2fcb6b8d..d3c2ea847da2 100644 --- a/tests/orion/models/test_work_queues.py +++ b/tests/orion/models/test_work_queues.py @@ -238,7 +238,12 @@ async def test_delete_work_queue_returns_false_if_does_not_exist(self, session): class TestGetRunsInWorkQueue: - running_flow_count = 4 + running_flow_states = [ + schemas.states.StateType.PENDING, + schemas.states.StateType.CANCELLING, + schemas.states.StateType.RUNNING, + schemas.states.StateType.RUNNING, + ] @pytest.fixture async def work_queue_2(self, session): @@ -272,15 +277,8 @@ async def scheduled_flow_runs(self, session, deployment, work_queue, work_queue_ @pytest.fixture async def running_flow_runs(self, session, deployment, work_queue, work_queue_2): - for i in range(self.running_flow_count): + for state_type in self.running_flow_states: for wq in [work_queue, work_queue_2]: - if i == 0: - state_type = "PENDING" - elif i == 1: - state_type = "CANCELLING" - else: - state_type = "RUNNING" - await models.flow_runs.create_flow_run( session=session, flow_run=schemas.core.FlowRun( @@ -378,7 +376,7 @@ async def test_get_runs_in_queue_concurrency_limit( ) assert len(runs_wq1) == max( - 0, min(3, concurrency_limit - self.running_flow_count) + 0, min(3, concurrency_limit - len(self.running_flow_states)) ) @pytest.mark.parametrize("limit", [10, 1]) @@ -404,4 +402,6 @@ async def test_get_runs_in_queue_concurrency_limit_and_limit( session=session, work_queue_id=work_queue.id, limit=limit ) - assert len(runs_wq1) == min(limit, concurrency_limit - self.running_flow_count) + assert len(runs_wq1) == min( + limit, concurrency_limit - len(self.running_flow_states) + ) From 5ed9a91c76d3115be29e73ba50bc60819d2189a8 Mon Sep 17 00:00:00 2001 From: Chris Pickett Date: Mon, 12 Dec 2022 11:29:40 -0500 Subject: [PATCH 6/9] Fix fizzle test --- tests/orion/orchestration/test_core_policy.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/orion/orchestration/test_core_policy.py b/tests/orion/orchestration/test_core_policy.py index e8d0277a3172..eca7cfacd939 100644 --- a/tests/orion/orchestration/test_core_policy.py +++ b/tests/orion/orchestration/test_core_policy.py @@ -1682,7 +1682,12 @@ async def before_transition(self, initial_state, proposed_state, context): mutated_state.type = random.choice( list( set(states.StateType) - - {initial_state.type, proposed_state.type} + - { + initial_state.type, + proposed_state.type, + states.StateType.RUNNING, + states.StateType.CANCELLING, + } ) ) await self.reject_transition( From f9906bbd42d33457219b5f3c8a24e0e7909ab11a Mon Sep 17 00:00:00 2001 From: Chris Pickett Date: Mon, 12 Dec 2022 11:58:17 -0500 Subject: [PATCH 7/9] Add CANCELLING handling to `PreventRedundantTransitions` --- src/prefect/orion/orchestration/core_policy.py | 18 ++++++++++++++++-- tests/orion/orchestration/test_core_policy.py | 1 + 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/prefect/orion/orchestration/core_policy.py b/src/prefect/orion/orchestration/core_policy.py index eb110d328cc1..33088b0d849e 100644 --- a/src/prefect/orion/orchestration/core_policy.py +++ b/src/prefect/orion/orchestration/core_policy.py @@ -702,10 +702,23 @@ class PreventRedundantTransitions(BaseOrchestrationRule): StateType.SCHEDULED: 1, StateType.PENDING: 2, StateType.RUNNING: 3, + StateType.CANCELLING: 4, } - FROM_STATES = [StateType.SCHEDULED, StateType.PENDING, StateType.RUNNING, None] - TO_STATES = [StateType.SCHEDULED, StateType.PENDING, StateType.RUNNING, None] + FROM_STATES = [ + StateType.SCHEDULED, + StateType.PENDING, + StateType.RUNNING, + StateType.CANCELLING, + None, + ] + TO_STATES = [ + StateType.SCHEDULED, + StateType.PENDING, + StateType.RUNNING, + StateType.CANCELLING, + None, + ] async def before_transition( self, @@ -715,6 +728,7 @@ async def before_transition( ) -> None: initial_state_type = initial_state.type if initial_state else None proposed_state_type = proposed_state.type if proposed_state else None + if ( self.STATE_PROGRESS[proposed_state_type] <= self.STATE_PROGRESS[initial_state_type] diff --git a/tests/orion/orchestration/test_core_policy.py b/tests/orion/orchestration/test_core_policy.py index eca7cfacd939..a340901981bc 100644 --- a/tests/orion/orchestration/test_core_policy.py +++ b/tests/orion/orchestration/test_core_policy.py @@ -968,6 +968,7 @@ async def test_all_other_transitions_are_accepted( @pytest.mark.parametrize("run_type", ["task", "flow"]) class TestPreventingRedundantTransitionsRule: active_states = ( + states.StateType.CANCELLING, states.StateType.RUNNING, states.StateType.PENDING, states.StateType.SCHEDULED, From e95ad9337f5b80d84c21163474b85d442860bd8a Mon Sep 17 00:00:00 2001 From: Chris Pickett Date: Mon, 12 Dec 2022 13:34:59 -0500 Subject: [PATCH 8/9] Update tests/orion/models/test_work_queues.py Co-authored-by: Zach Angell <42625717+zangell44@users.noreply.github.com> --- tests/orion/models/test_work_queues.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/orion/models/test_work_queues.py b/tests/orion/models/test_work_queues.py index d3c2ea847da2..765a0f53be01 100644 --- a/tests/orion/models/test_work_queues.py +++ b/tests/orion/models/test_work_queues.py @@ -242,7 +242,6 @@ class TestGetRunsInWorkQueue: schemas.states.StateType.PENDING, schemas.states.StateType.CANCELLING, schemas.states.StateType.RUNNING, - schemas.states.StateType.RUNNING, ] @pytest.fixture From 880e2c75ef6e1fd7364cb6e69fc832da348b6dc7 Mon Sep 17 00:00:00 2001 From: Chris Pickett Date: Wed, 14 Dec 2022 15:27:15 -0500 Subject: [PATCH 9/9] Fix heisentest in test_core_policy --- tests/orion/orchestration/test_core_policy.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/orion/orchestration/test_core_policy.py b/tests/orion/orchestration/test_core_policy.py index a340901981bc..c4c3eb7599db 100644 --- a/tests/orion/orchestration/test_core_policy.py +++ b/tests/orion/orchestration/test_core_policy.py @@ -1772,7 +1772,11 @@ async def before_transition(self, initial_state, proposed_state, context): mutated_state.type = random.choice( list( set(states.StateType) - - {initial_state.type, proposed_state.type} + - { + initial_state.type, + proposed_state.type, + states.StateType.CANCELLING, + } ) ) await self.reject_transition(