Skip to content

Commit

Permalink
Create a CANCELLING state type (#7794)
Browse files Browse the repository at this point in the history
Co-authored-by: Zach Angell <[email protected]>
  • Loading branch information
2 people authored and zanieb committed Jan 25, 2023
1 parent ee49d48 commit ecaea25
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 14 deletions.
4 changes: 4 additions & 0 deletions src/prefect/orion/database/migrations/MIGRATION-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ Each time a database migration is written, an entry is included here with:

This gives us a history of changes and will create merge conflicts if two migrations are made at once, flagging situations where a branch needs to be updated before merging.

# Add `CANCELLING` to StateType enum
SQLite: None
Postgres: `9326a6aee18b`

# Add infrastructure_pid to flow runs
SQLite: `7201de756d85`
Postgres: `5d526270ddb4`
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Add CANCELLING to state type enum
Revision ID: 9326a6aee18b
Revises: f7587d6c5776
Create Date: 2022-12-06 16:40:28.282753
"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "9326a6aee18b"
down_revision = "f7587d6c5776"
branch_labels = None
depends_on = None


def upgrade():
op.execute("ALTER TYPE state_type ADD VALUE IF NOT EXISTS 'CANCELLING';")


def downgrade():
pass
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Rename Worker Pools to Work Pools
Revision ID: d481d5058a19
Revises: f7587d6c5776
Revises: 9326a6aee18b
Create Date: 2023-01-08 18:01:42.559990
"""
Expand All @@ -12,7 +12,7 @@

# revision identifiers, used by Alembic.
revision = "d481d5058a19"
down_revision = "f7587d6c5776"
down_revision = "9326a6aee18b"
branch_labels = None
depends_on = None

Expand Down
2 changes: 1 addition & 1 deletion src/prefect/orion/database/query_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def get_scheduled_flow_runs_from_work_queues(
db.FlowRun,
sa.and_(
self._flow_run_work_queue_join_clause(db.FlowRun, db.WorkQueue),
db.FlowRun.state_type.in_(["RUNNING", "PENDING"]),
db.FlowRun.state_type.in_(["RUNNING", "PENDING", "CANCELLING"]),
),
isouter=True,
)
Expand Down
26 changes: 22 additions & 4 deletions src/prefect/orion/orchestration/core_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ async def cleanup(

class ReleaseTaskConcurrencySlots(BaseUniversalTransform):
"""
Releases any concurrency slots held by a run upon exiting a Running state.
Releases any concurrency slots held by a run upon exiting a Running or
Cancelling state.
"""

async def after_transition(
Expand All @@ -173,7 +174,10 @@ async def after_transition(
if self.nullified_transition():
return

if not context.validated_state.is_running():
if context.validated_state.type not in [
states.StateType.RUNNING,
states.StateType.CANCELLING,
]:
filtered_limits = (
await concurrency_limits.filter_concurrency_limits_for_orchestration(
context.session, tags=context.run.tags
Expand Down Expand Up @@ -714,10 +718,23 @@ class PreventRedundantTransitions(BaseOrchestrationRule):
StateType.SCHEDULED: 1,
StateType.PENDING: 2,
StateType.RUNNING: 3,
StateType.CANCELLING: 4,
}

FROM_STATES = [StateType.SCHEDULED, StateType.PENDING, StateType.RUNNING, None]
TO_STATES = [StateType.SCHEDULED, StateType.PENDING, StateType.RUNNING, None]
FROM_STATES = [
StateType.SCHEDULED,
StateType.PENDING,
StateType.RUNNING,
StateType.CANCELLING,
None,
]
TO_STATES = [
StateType.SCHEDULED,
StateType.PENDING,
StateType.RUNNING,
StateType.CANCELLING,
None,
]

async def before_transition(
self,
Expand All @@ -727,6 +744,7 @@ async def before_transition(
) -> None:
initial_state_type = initial_state.type if initial_state else None
proposed_state_type = proposed_state.type if proposed_state else None

if (
self.STATE_PROGRESS[proposed_state_type]
<= self.STATE_PROGRESS[initial_state_type]
Expand Down
10 changes: 10 additions & 0 deletions src/prefect/orion/schemas/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class StateType(AutoEnum):
CANCELLED = AutoEnum.auto()
CRASHED = AutoEnum.auto()
PAUSED = AutoEnum.auto()
CANCELLING = AutoEnum.auto()


TERMINAL_STATES = {
Expand Down Expand Up @@ -268,6 +269,15 @@ def Crashed(cls: Type[State] = State, **kwargs) -> State:
return cls(type=StateType.CRASHED, **kwargs)


def Cancelling(cls: Type[State] = State, **kwargs) -> State:
"""Convenience function for creating `Cancelling` states.
Returns:
State: a Cancelling state
"""
return cls(type=StateType.CANCELLING, **kwargs)


def Cancelled(cls: Type[State] = State, **kwargs) -> State:
"""Convenience function for creating `Cancelled` states.
Expand Down
24 changes: 19 additions & 5 deletions tests/orion/models/test_work_queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ async def test_delete_work_queue_returns_false_if_does_not_exist(self, session):


class TestGetRunsInWorkQueue:
running_flow_states = [
schemas.states.StateType.PENDING,
schemas.states.StateType.CANCELLING,
schemas.states.StateType.RUNNING,
]

@pytest.fixture
async def work_queue_2(self, session):
work_queue = await models.work_queues.create_work_queue(
Expand Down Expand Up @@ -270,7 +276,7 @@ async def scheduled_flow_runs(self, session, deployment, work_queue, work_queue_

@pytest.fixture
async def running_flow_runs(self, session, deployment, work_queue, work_queue_2):
for i in range(3):
for state_type in self.running_flow_states:
for wq in [work_queue, work_queue_2]:
await models.flow_runs.create_flow_run(
session=session,
Expand All @@ -279,7 +285,7 @@ async def running_flow_runs(self, session, deployment, work_queue, work_queue_2)
deployment_id=deployment.id,
work_queue_name=wq.name,
state=schemas.states.State(
type="RUNNING" if i == 0 else "PENDING",
type=state_type,
timestamp=pendulum.now("UTC").subtract(seconds=10),
),
),
Expand Down Expand Up @@ -368,7 +374,9 @@ async def test_get_runs_in_queue_concurrency_limit(
session=session, work_queue_id=work_queue.id
)

assert len(runs_wq1) == max(0, min(3, concurrency_limit - 3))
assert len(runs_wq1) == max(
0, min(3, concurrency_limit - len(self.running_flow_states))
)

@pytest.mark.parametrize("limit", [10, 1])
async def test_get_runs_in_queue_concurrency_limit_and_limit(
Expand All @@ -379,14 +387,20 @@ async def test_get_runs_in_queue_concurrency_limit_and_limit(
running_flow_runs,
limit,
):
concurrency_limit = 5

await models.work_queues.update_work_queue(
session=session,
work_queue_id=work_queue.id,
work_queue=schemas.actions.WorkQueueUpdate(concurrency_limit=5),
work_queue=schemas.actions.WorkQueueUpdate(
concurrency_limit=concurrency_limit
),
)

runs_wq1 = await models.work_queues.get_runs_in_work_queue(
session=session, work_queue_id=work_queue.id, limit=limit
)

assert len(runs_wq1) == min(limit, 2)
assert len(runs_wq1) == min(
limit, concurrency_limit - len(self.running_flow_states)
)
121 changes: 119 additions & 2 deletions tests/orion/orchestration/test_core_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,7 @@ async def test_all_other_transitions_are_accepted(
@pytest.mark.parametrize("run_type", ["task", "flow"])
class TestPreventingRedundantTransitionsRule:
active_states = (
states.StateType.CANCELLING,
states.StateType.RUNNING,
states.StateType.PENDING,
states.StateType.SCHEDULED,
Expand Down Expand Up @@ -1300,6 +1301,113 @@ async def test_basic_concurrency_limiting(
assert task2_run_retry_ctx.response_status == SetStateStatus.ACCEPT
assert (await self.count_concurrency_slots(session, "some tag")) == 1

async def test_concurrency_limit_cancelling_transition(
self,
session,
run_type,
initialize_orchestration,
):
await self.create_concurrency_limit(session, "some tag", 1)
concurrency_policy = [SecureTaskConcurrencySlots, ReleaseTaskConcurrencySlots]
running_transition = (states.StateType.PENDING, states.StateType.RUNNING)
cancelling_transition = (states.StateType.RUNNING, states.StateType.CANCELLING)
cancelled_transition = (states.StateType.CANCELLING, states.StateType.CANCELLED)

# before any runs, no active concurrency slots are in use
assert (await self.count_concurrency_slots(session, "some tag")) == 0

task1_running_ctx = await initialize_orchestration(
session, "task", *running_transition, run_tags=["some tag"]
)

async with contextlib.AsyncExitStack() as stack:
for rule in concurrency_policy:
task1_running_ctx = await stack.enter_async_context(
rule(task1_running_ctx, *running_transition)
)
await task1_running_ctx.validate_proposed_state()

# a first task run against a concurrency limited tag will be accepted
assert task1_running_ctx.response_status == SetStateStatus.ACCEPT

task2_running_ctx = await initialize_orchestration(
session, "task", *running_transition, run_tags=["some tag"]
)

async with contextlib.AsyncExitStack() as stack:
for rule in concurrency_policy:
task2_running_ctx = await stack.enter_async_context(
rule(task2_running_ctx, *running_transition)
)
await task2_running_ctx.validate_proposed_state()

# the first task hasn't completed, so the concurrently running second task is
# told to wait
assert task2_running_ctx.response_status == SetStateStatus.WAIT

# the number of slots occupied by active runs is equal to the concurrency limit
assert (await self.count_concurrency_slots(session, "some tag")) == 1

task1_cancelling_ctx = await initialize_orchestration(
session,
"task",
*cancelling_transition,
run_override=task1_running_ctx.run,
run_tags=["some tag"],
)

async with contextlib.AsyncExitStack() as stack:
for rule in concurrency_policy:
task1_cancelling_ctx = await stack.enter_async_context(
rule(task1_cancelling_ctx, *cancelling_transition)
)
await task1_cancelling_ctx.validate_proposed_state()

# the first task run will transition into a cancelling state, but
# maintain a hold on the concurrency slot
assert task1_cancelling_ctx.response_status == SetStateStatus.ACCEPT
assert (await self.count_concurrency_slots(session, "some tag")) == 1

task1_cancelled_ctx = await initialize_orchestration(
session,
"task",
*cancelled_transition,
run_override=task1_running_ctx.run,
run_tags=["some tag"],
)

async with contextlib.AsyncExitStack() as stack:
for rule in concurrency_policy:
task1_cancelled_ctx = await stack.enter_async_context(
rule(task1_cancelled_ctx, *cancelled_transition)
)
await task1_cancelled_ctx.validate_proposed_state()

# the first task run will transition into a cancelled state, yielding a
# concurrency slot
assert task1_cancelled_ctx.response_status == SetStateStatus.ACCEPT
assert (await self.count_concurrency_slots(session, "some tag")) == 0

# the second task tries to run again, this time the transition will be accepted
# now that a concurrency slot has been freed
task2_run_retry_ctx = await initialize_orchestration(
session,
"task",
*running_transition,
run_override=task2_running_ctx.run,
run_tags=["some tag"],
)

async with contextlib.AsyncExitStack() as stack:
for rule in concurrency_policy:
task2_run_retry_ctx = await stack.enter_async_context(
rule(task2_run_retry_ctx, *running_transition)
)
await task2_run_retry_ctx.validate_proposed_state()

assert task2_run_retry_ctx.response_status == SetStateStatus.ACCEPT
assert (await self.count_concurrency_slots(session, "some tag")) == 1

async def test_concurrency_limiting_aborts_transitions_with_zero_limit(
self,
session,
Expand Down Expand Up @@ -1728,7 +1836,12 @@ async def before_transition(self, initial_state, proposed_state, context):
mutated_state.type = random.choice(
list(
set(states.StateType)
- {initial_state.type, proposed_state.type}
- {
initial_state.type,
proposed_state.type,
states.StateType.RUNNING,
states.StateType.CANCELLING,
}
)
)
await self.reject_transition(
Expand Down Expand Up @@ -1812,7 +1925,11 @@ async def before_transition(self, initial_state, proposed_state, context):
mutated_state.type = random.choice(
list(
set(states.StateType)
- {initial_state.type, proposed_state.type}
- {
initial_state.type,
proposed_state.type,
states.StateType.CANCELLING,
}
)
)
await self.reject_transition(
Expand Down

0 comments on commit ecaea25

Please sign in to comment.