Skip to content

Commit

Permalink
refactor: revamp update mechanism of session & kernel status (#2311)
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa authored Sep 13, 2024
1 parent 2ac640f commit 195261c
Show file tree
Hide file tree
Showing 6 changed files with 314 additions and 402 deletions.
1 change: 1 addition & 0 deletions changes/2311.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Enhance update mechanism of session & kernel status.
1 change: 1 addition & 0 deletions src/ai/backend/manager/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class LockID(enum.IntEnum):
LOCKID_SCALE_TIMER = 193
LOCKID_LOG_CLEANUP_TIMER = 195
LOCKID_IDLE_CHECK_TIMER = 196
LOCKID_SESSION_STATUS_UPDATE_TIMER = 197


SERVICE_MAX_RETRIES = 5 # FIXME: make configurable
Expand Down
4 changes: 4 additions & 0 deletions src/ai/backend/manager/models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,8 +651,12 @@ async def get_kernel_to_update_status(
cls,
db_session: SASession,
kernel_id: KernelId,
*,
for_update: bool = True,
) -> KernelRow:
_stmt = sa.select(KernelRow).where(KernelRow.id == kernel_id)
if for_update:
_stmt = _stmt.with_for_update()
kernel_row = cast(KernelRow | None, await db_session.scalar(_stmt))
if kernel_row is None:
raise KernelNotFound(f"Kernel not found (id:{kernel_id})")
Expand Down
265 changes: 208 additions & 57 deletions src/ai/backend/manager/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import asyncio
import enum
import logging
import textwrap
from collections.abc import Iterable, Mapping, Sequence
from contextlib import asynccontextmanager as actxmgr
from dataclasses import dataclass, field
from datetime import datetime
from decimal import Decimal
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -21,6 +23,7 @@

import aiotools
import graphene
import redis.exceptions
import sqlalchemy as sa
from dateutil.parser import parse as dtparse
from dateutil.tz import tzutc
Expand All @@ -30,10 +33,20 @@
from sqlalchemy.ext.asyncio import AsyncSession as SASession
from sqlalchemy.orm import load_only, noload, relationship, selectinload

from ai.backend.common import redis_helper
from ai.backend.common.events import (
EventDispatcher,
EventProducer,
SessionStartedEvent,
SessionTerminatedEvent,
)
from ai.backend.common.plugin.hook import HookPluginContext
from ai.backend.common.types import (
AccessKey,
ClusterMode,
KernelId,
RedisConnectionInfo,
ResourceSlot,
SessionId,
SessionResult,
SessionTypes,
Expand Down Expand Up @@ -95,6 +108,7 @@
JSONCoalesceExpr,
agg_to_array,
execute_with_retry,
execute_with_txn_retry,
sql_json_merge,
)

Expand Down Expand Up @@ -795,62 +809,6 @@ async def get_session_id_by_kernel(
async with db.begin_readonly_session() as db_session:
return await db_session.scalar(query)

@classmethod
async def transit_session_status(
    cls,
    db: ExtendedAsyncSAEngine,
    session_id: SessionId,
    *,
    status_info: str | None = None,
) -> SessionStatus | None:
    """
    Derive the session status from its kernels and persist it.

    The session row is selected with a row lock together with the status and
    cluster role of its kernels. The status computed by
    ``determine_session_status()`` is written back only when it is listed in
    ``SESSION_STATUS_TRANSITION_MAP`` for the session's current status.

    :param db: Engine used to open the read-write DB session.
    :param session_id: ID of the session to transit.
    :param status_info: When not ``None``, stored as the session's ``status_info``.
    :return: The newly stored status, or ``None`` when the transition is not allowed.
    """
    now = datetime.now(tzutc())

    async def _check_and_update() -> SessionStatus | None:
        # Load only the columns needed to decide the next status.
        kernel_loader = selectinload(SessionRow.kernels).options(
            noload("*"), load_only(KernelRow.status, KernelRow.cluster_role)
        )
        select_stmt = (
            sa.select(SessionRow)
            .where(SessionRow.id == session_id)
            .with_for_update()
            .options(noload("*"), load_only(SessionRow.status), kernel_loader)
        )
        async with db.begin_session() as db_session:
            fetched = await db_session.scalars(select_stmt)
            session_row: SessionRow = fetched.first()
            next_status = determine_session_status(session_row.kernels)
            allowed_statuses = SESSION_STATUS_TRANSITION_MAP[session_row.status]
            if next_status not in allowed_statuses:
                # TODO: log or raise error
                return None

            history_patch = {next_status.name: now.isoformat()}
            values = {
                "status": next_status,
                "status_history": sql_json_merge(
                    SessionRow.status_history,
                    (),
                    history_patch,
                ),
            }
            # Terminal statuses also record the termination timestamp.
            if next_status in (SessionStatus.CANCELLED, SessionStatus.TERMINATED):
                values["terminated_at"] = now
            if status_info is not None:
                values["status_info"] = status_info
            update_stmt = sa.update(SessionRow).where(SessionRow.id == session_id)
            await db_session.execute(update_stmt.values(**values))
            return next_status

    return await execute_with_retry(_check_and_update)

@classmethod
async def get_session_to_determine_status(
cls, db_session: SASession, session_id: SessionId
Expand All @@ -860,7 +818,12 @@ async def get_session_to_determine_status(
.where(SessionRow.id == session_id)
.options(
selectinload(SessionRow.kernels).options(
load_only(KernelRow.status, KernelRow.cluster_role, KernelRow.status_info)
load_only(
KernelRow.status,
KernelRow.cluster_role,
KernelRow.status_info,
KernelRow.occupied_slots,
)
),
)
)
Expand Down Expand Up @@ -1218,6 +1181,194 @@ async def get_sgroup_managed_sessions(
return result.scalars().all()


class SessionLifecycleManager:
status_set_key = "session_status_update"

def __init__(
self,
db: ExtendedAsyncSAEngine,
redis_obj: RedisConnectionInfo,
event_dispatcher: EventDispatcher,
event_producer: EventProducer,
hook_plugin_ctx: HookPluginContext,
) -> None:
self.db = db
self.redis_obj = redis_obj
self.event_dispatcher = event_dispatcher
self.event_producer = event_producer
self.hook_plugin_ctx = hook_plugin_ctx

def _encode(sid: SessionId) -> bytes:
return sid.bytes

def _decode(raw_sid: bytes) -> SessionId:
return SessionId(UUID(bytes=raw_sid))

self._encoder = _encode
self._decoder = _decode

async def _transit_session_status(
self,
db_conn: SAConnection,
session_id: SessionId,
status_changed_at: datetime | None = None,
) -> tuple[SessionRow, bool]:
now = status_changed_at or datetime.now(tzutc())

async def _get_and_transit(
db_session: SASession,
) -> tuple[SessionRow, bool]:
session_row = await SessionRow.get_session_to_determine_status(db_session, session_id)
transited = session_row.determine_and_set_status(status_changed_at=now)

def _calculate_session_occupied_slots(session_row: SessionRow):
session_occupying_slots = ResourceSlot.from_json({**session_row.occupying_slots})
for row in session_row.kernels:
kernel_row = cast(KernelRow, row)
kernel_allocs = kernel_row.occupied_slots
session_occupying_slots.sync_keys(kernel_allocs)
for key, val in session_occupying_slots.items():
session_occupying_slots[key] = str(
Decimal(val) + Decimal(kernel_allocs[key])
)
session_row.occupying_slots = session_occupying_slots

match session_row.status:
case SessionStatus.PREPARING | SessionStatus.RUNNING:
_calculate_session_occupied_slots(session_row)
return session_row, transited

return await execute_with_txn_retry(_get_and_transit, self.db.begin_session, db_conn)

async def _post_status_transition(
self,
session_row: SessionRow,
) -> None:
match session_row.status:
case SessionStatus.RUNNING:
log.debug(
"Producing SessionStartedEvent({}, {})",
session_row.id,
session_row.creation_id,
)
await self.event_producer.produce_event(
SessionStartedEvent(session_row.id, session_row.creation_id),
)
await self.hook_plugin_ctx.notify(
"POST_START_SESSION",
(
session_row.id,
session_row.name,
session_row.access_key,
),
)
await self.event_producer.produce_event(
SessionStartedEvent(session_row.id, session_row.creation_id),
)
case SessionStatus.TERMINATED:
await self.event_producer.produce_event(
SessionTerminatedEvent(session_row.id, session_row.main_kernel.status_info),
)
case _:
pass

async def transit_session_status(
self,
session_ids: Iterable[SessionId],
status_changed_at: datetime | None = None,
) -> list[SessionRow]:
if not session_ids:
return []
now = status_changed_at or datetime.now(tzutc())
transited_sessions: list[SessionRow] = []
async with self.db.connect() as db_conn:
for sid in session_ids:
row, is_transited = await self._transit_session_status(db_conn, sid, now)
if is_transited:
transited_sessions.append(row)
for row in transited_sessions:
await self._post_status_transition(row)
return transited_sessions

async def register_status_updatable_session(self, session_ids: Iterable[SessionId]) -> None:
sadd_session_ids_script = textwrap.dedent("""
local key = KEYS[1]
local values = ARGV
return redis.call('SADD', key, unpack(values))
""")
try:
await redis_helper.execute_script(
self.redis_obj,
"session_status_update",
sadd_session_ids_script,
[self.status_set_key],
[self._encoder(sid) for sid in session_ids],
)
except (
redis.exceptions.RedisError,
redis.exceptions.RedisClusterException,
redis.exceptions.ChildDeadlockedError,
) as e:
log.warning(f"Failed to update session status to redis, skip. (e:{repr(e)})")

async def get_status_updatable_sessions(self) -> list[SessionId]:
pop_all_session_id_script = textwrap.dedent("""
local key = KEYS[1]
local count = redis.call('SCARD', key)
return redis.call('SPOP', key, count)
""")
try:
raw_result = await redis_helper.execute_script(
self.redis_obj,
"pop_all_session_id_to_update_status",
pop_all_session_id_script,
[self.status_set_key],
[],
)
except (
redis.exceptions.RedisError,
redis.exceptions.RedisClusterException,
redis.exceptions.ChildDeadlockedError,
) as e:
log.warning(f"Failed to fetch session status data from redis, skip. (e:{repr(e)})")
return []
raw_result = cast(list[bytes], raw_result)
result: list[SessionId] = []
for raw_session_id in raw_result:
try:
result.append(self._decoder(raw_session_id))
except (ValueError, SyntaxError):
log.warning(f"Cannot parse session id, skip. (id:{raw_session_id})")
continue
return result

async def deregister_status_updatable_session(
self,
session_ids: Iterable[SessionId],
) -> int:
srem_session_ids_script = textwrap.dedent("""
local key = KEYS[1]
local values = ARGV
return redis.call('SREM', key, unpack(values))
""")
try:
ret = await redis_helper.execute_script(
self.redis_obj,
"session_status_update",
srem_session_ids_script,
[self.status_set_key],
[self._encoder(sid) for sid in session_ids],
)
except (
redis.exceptions.RedisError,
redis.exceptions.RedisClusterException,
redis.exceptions.ChildDeadlockedError,
) as e:
log.warning(f"Failed to remove session status data from redis, skip. (e:{repr(e)})")
return 0
return ret


class SessionDependencyRow(Base):
__tablename__ = "session_dependencies"
session_id = sa.Column(
Expand Down
Loading

0 comments on commit 195261c

Please sign in to comment.