Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a CANCELLING state type #7794

Merged
merged 9 commits into from
Dec 15, 2022
4 changes: 4 additions & 0 deletions src/prefect/orion/database/migrations/MIGRATION-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ Each time a database migration is written, an entry is included here with:

This gives us a history of changes and will create merge conflicts if two migrations are made at once, flagging situations where a branch needs to be updated before merging.

# Add `CANCELLING` to StateType enum
SQLite: None
Postgres: `9326a6aee18b`

# Add infrastructure_pid to flow runs
SQLite: `7201de756d85`
Postgres: `5d526270ddb4`
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Add CANCELLING to state type enum

Revision ID: 9326a6aee18b
Revises: 5e4f924ff96c
Create Date: 2022-12-06 16:40:28.282753

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "9326a6aee18b"
down_revision = "5e4f924ff96c"
branch_labels = None
depends_on = None


def upgrade():
op.execute("ALTER TYPE state_type ADD VALUE IF NOT EXISTS 'CANCELLING';")


def downgrade():
pass
2 changes: 1 addition & 1 deletion src/prefect/orion/database/query_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def get_scheduled_flow_runs_from_work_queues(
db.FlowRun,
sa.and_(
self._flow_run_work_queue_join_clause(db.FlowRun, db.WorkQueue),
db.FlowRun.state_type.in_(["RUNNING", "PENDING"]),
db.FlowRun.state_type.in_(["RUNNING", "PENDING", "CANCELLING"]),
),
isouter=True,
)
Expand Down
26 changes: 22 additions & 4 deletions src/prefect/orion/orchestration/core_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ async def cleanup(

class ReleaseTaskConcurrencySlots(BaseUniversalTransform):
"""
Releases any concurrency slots held by a run upon exiting a Running state.
Releases any concurrency slots held by a run upon exiting a Running or
Cancelling state.
"""

async def after_transition(
Expand All @@ -172,7 +173,10 @@ async def after_transition(
if self.nullified_transition():
return

if not context.validated_state.is_running():
if context.validated_state.type not in [
states.StateType.RUNNING,
states.StateType.CANCELLING,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will need a matching update in cloud

]:
filtered_limits = (
await concurrency_limits.filter_concurrency_limits_for_orchestration(
context.session, tags=context.run.tags
Expand Down Expand Up @@ -698,10 +702,23 @@ class PreventRedundantTransitions(BaseOrchestrationRule):
StateType.SCHEDULED: 1,
StateType.PENDING: 2,
StateType.RUNNING: 3,
StateType.CANCELLING: 4,
}

FROM_STATES = [StateType.SCHEDULED, StateType.PENDING, StateType.RUNNING, None]
TO_STATES = [StateType.SCHEDULED, StateType.PENDING, StateType.RUNNING, None]
FROM_STATES = [
StateType.SCHEDULED,
StateType.PENDING,
StateType.RUNNING,
StateType.CANCELLING,
None,
]
TO_STATES = [
StateType.SCHEDULED,
StateType.PENDING,
StateType.RUNNING,
StateType.CANCELLING,
None,
]

async def before_transition(
self,
Expand All @@ -711,6 +728,7 @@ async def before_transition(
) -> None:
initial_state_type = initial_state.type if initial_state else None
proposed_state_type = proposed_state.type if proposed_state else None

if (
self.STATE_PROGRESS[proposed_state_type]
<= self.STATE_PROGRESS[initial_state_type]
Expand Down
10 changes: 10 additions & 0 deletions src/prefect/orion/schemas/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class StateType(AutoEnum):
CANCELLED = AutoEnum.auto()
CRASHED = AutoEnum.auto()
PAUSED = AutoEnum.auto()
CANCELLING = AutoEnum.auto()


TERMINAL_STATES = {
Expand Down Expand Up @@ -267,6 +268,15 @@ def Crashed(cls: Type[State] = State, **kwargs) -> State:
return cls(type=StateType.CRASHED, **kwargs)


def Cancelling(cls: Type[State] = State, **kwargs) -> State:
"""Convenience function for creating `Cancelling` states.

Returns:
State: a Cancelling state
"""
return cls(type=StateType.CANCELLING, **kwargs)


def Cancelled(cls: Type[State] = State, **kwargs) -> State:
"""Convenience function for creating `Cancelled` states.

Expand Down
24 changes: 19 additions & 5 deletions tests/orion/models/test_work_queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ async def test_delete_work_queue_returns_false_if_does_not_exist(self, session):


class TestGetRunsInWorkQueue:
running_flow_states = [
schemas.states.StateType.PENDING,
schemas.states.StateType.CANCELLING,
schemas.states.StateType.RUNNING,
bunchesofdonald marked this conversation as resolved.
Show resolved Hide resolved
]

@pytest.fixture
async def work_queue_2(self, session):
work_queue = await models.work_queues.create_work_queue(
Expand Down Expand Up @@ -270,7 +276,7 @@ async def scheduled_flow_runs(self, session, deployment, work_queue, work_queue_

@pytest.fixture
async def running_flow_runs(self, session, deployment, work_queue, work_queue_2):
for i in range(3):
for state_type in self.running_flow_states:
for wq in [work_queue, work_queue_2]:
await models.flow_runs.create_flow_run(
session=session,
Expand All @@ -279,7 +285,7 @@ async def running_flow_runs(self, session, deployment, work_queue, work_queue_2)
deployment_id=deployment.id,
work_queue_name=wq.name,
state=schemas.states.State(
type="RUNNING" if i == 0 else "PENDING",
type=state_type,
timestamp=pendulum.now("UTC").subtract(seconds=10),
),
),
Expand Down Expand Up @@ -368,7 +374,9 @@ async def test_get_runs_in_queue_concurrency_limit(
session=session, work_queue_id=work_queue.id
)

assert len(runs_wq1) == max(0, min(3, concurrency_limit - 3))
assert len(runs_wq1) == max(
0, min(3, concurrency_limit - len(self.running_flow_states))
)

@pytest.mark.parametrize("limit", [10, 1])
async def test_get_runs_in_queue_concurrency_limit_and_limit(
Expand All @@ -379,14 +387,20 @@ async def test_get_runs_in_queue_concurrency_limit_and_limit(
running_flow_runs,
limit,
):
concurrency_limit = 5

await models.work_queues.update_work_queue(
session=session,
work_queue_id=work_queue.id,
work_queue=schemas.actions.WorkQueueUpdate(concurrency_limit=5),
work_queue=schemas.actions.WorkQueueUpdate(
concurrency_limit=concurrency_limit
),
)

runs_wq1 = await models.work_queues.get_runs_in_work_queue(
session=session, work_queue_id=work_queue.id, limit=limit
)

assert len(runs_wq1) == min(limit, 2)
assert len(runs_wq1) == min(
limit, concurrency_limit - len(self.running_flow_states)
)
121 changes: 119 additions & 2 deletions tests/orion/orchestration/test_core_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,7 @@ async def test_all_other_transitions_are_accepted(
@pytest.mark.parametrize("run_type", ["task", "flow"])
class TestPreventingRedundantTransitionsRule:
active_states = (
states.StateType.CANCELLING,
states.StateType.RUNNING,
states.StateType.PENDING,
states.StateType.SCHEDULED,
Expand Down Expand Up @@ -1147,6 +1148,113 @@ async def test_basic_concurrency_limiting(
assert task2_run_retry_ctx.response_status == SetStateStatus.ACCEPT
assert (await self.count_concurrency_slots(session, "some tag")) == 1

async def test_concurrency_limit_cancelling_transition(
self,
session,
run_type,
initialize_orchestration,
):
await self.create_concurrency_limit(session, "some tag", 1)
concurrency_policy = [SecureTaskConcurrencySlots, ReleaseTaskConcurrencySlots]
running_transition = (states.StateType.PENDING, states.StateType.RUNNING)
cancelling_transition = (states.StateType.RUNNING, states.StateType.CANCELLING)
cancelled_transition = (states.StateType.CANCELLING, states.StateType.CANCELLED)

# before any runs, no active concurrency slots are in use
assert (await self.count_concurrency_slots(session, "some tag")) == 0

task1_running_ctx = await initialize_orchestration(
session, "task", *running_transition, run_tags=["some tag"]
)

async with contextlib.AsyncExitStack() as stack:
for rule in concurrency_policy:
task1_running_ctx = await stack.enter_async_context(
rule(task1_running_ctx, *running_transition)
)
await task1_running_ctx.validate_proposed_state()

# a first task run against a concurrency limited tag will be accepted
assert task1_running_ctx.response_status == SetStateStatus.ACCEPT

task2_running_ctx = await initialize_orchestration(
session, "task", *running_transition, run_tags=["some tag"]
)

async with contextlib.AsyncExitStack() as stack:
for rule in concurrency_policy:
task2_running_ctx = await stack.enter_async_context(
rule(task2_running_ctx, *running_transition)
)
await task2_running_ctx.validate_proposed_state()

# the first task hasn't completed, so the concurrently running second task is
# told to wait
assert task2_running_ctx.response_status == SetStateStatus.WAIT

# the number of slots occupied by active runs is equal to the concurrency limit
assert (await self.count_concurrency_slots(session, "some tag")) == 1

task1_cancelling_ctx = await initialize_orchestration(
session,
"task",
*cancelling_transition,
run_override=task1_running_ctx.run,
run_tags=["some tag"],
)

async with contextlib.AsyncExitStack() as stack:
for rule in concurrency_policy:
task1_cancelling_ctx = await stack.enter_async_context(
rule(task1_cancelling_ctx, *cancelling_transition)
)
await task1_cancelling_ctx.validate_proposed_state()

# the first task run will transition into a cancelling state, but
# maintain a hold on the concurrency slot
assert task1_cancelling_ctx.response_status == SetStateStatus.ACCEPT
assert (await self.count_concurrency_slots(session, "some tag")) == 1

task1_cancelled_ctx = await initialize_orchestration(
session,
"task",
*cancelled_transition,
run_override=task1_running_ctx.run,
run_tags=["some tag"],
)

async with contextlib.AsyncExitStack() as stack:
for rule in concurrency_policy:
task1_cancelled_ctx = await stack.enter_async_context(
rule(task1_cancelled_ctx, *cancelled_transition)
)
await task1_cancelled_ctx.validate_proposed_state()

# the first task run will transition into a cancelled state, yielding a
# concurrency slot
assert task1_cancelled_ctx.response_status == SetStateStatus.ACCEPT
assert (await self.count_concurrency_slots(session, "some tag")) == 0

# the second task tries to run again, this time the transition will be accepted
# now that a concurrency slot has been freed
task2_run_retry_ctx = await initialize_orchestration(
session,
"task",
*running_transition,
run_override=task2_running_ctx.run,
run_tags=["some tag"],
)

async with contextlib.AsyncExitStack() as stack:
for rule in concurrency_policy:
task2_run_retry_ctx = await stack.enter_async_context(
rule(task2_run_retry_ctx, *running_transition)
)
await task2_run_retry_ctx.validate_proposed_state()

assert task2_run_retry_ctx.response_status == SetStateStatus.ACCEPT
assert (await self.count_concurrency_slots(session, "some tag")) == 1

async def test_concurrency_limiting_aborts_transitions_with_zero_limit(
self,
session,
Expand Down Expand Up @@ -1575,7 +1683,12 @@ async def before_transition(self, initial_state, proposed_state, context):
mutated_state.type = random.choice(
list(
set(states.StateType)
- {initial_state.type, proposed_state.type}
- {
initial_state.type,
proposed_state.type,
states.StateType.RUNNING,
states.StateType.CANCELLING,
}
)
)
await self.reject_transition(
Expand Down Expand Up @@ -1659,7 +1772,11 @@ async def before_transition(self, initial_state, proposed_state, context):
mutated_state.type = random.choice(
list(
set(states.StateType)
- {initial_state.type, proposed_state.type}
- {
initial_state.type,
proposed_state.type,
states.StateType.CANCELLING,
}
)
)
await self.reject_transition(
Expand Down