diff --git a/changelog.d/15891.feature b/changelog.d/15891.feature
new file mode 100644
index 000000000000..5a3d12a32e2f
--- /dev/null
+++ b/changelog.d/15891.feature
@@ -0,0 +1 @@
+Implements a task scheduler.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index dc79efcc142f..d25e3548e075 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -91,6 +91,7 @@
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.databases.main.stream import StreamWorkerStore
 from synapse.storage.databases.main.tags import TagsWorkerStore
+from synapse.storage.databases.main.task_scheduler import TaskSchedulerWorkerStore
 from synapse.storage.databases.main.transactions import TransactionWorkerStore
 from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
 from synapse.storage.databases.main.user_directory import UserDirectoryStore
@@ -144,6 +145,7 @@ class GenericWorkerStore(
     TransactionWorkerStore,
     LockStore,
     SessionStore,
+    TaskSchedulerWorkerStore,
 ):
     # Properties that multiple storage classes define. Tell mypy what the
     # expected type is.
diff --git a/synapse/handlers/task_scheduler.py b/synapse/handlers/task_scheduler.py
new file mode 100644
index 000000000000..ea124753aecb
--- /dev/null
+++ b/synapse/handlers/task_scheduler.py
@@ -0,0 +1,111 @@
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
+
+import attr
+
+from synapse.api.errors import SynapseError
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import JsonMapping, ScheduledTask, TaskState
+from synapse.util.stringutils import random_string
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+class TaskSchedulerHandler:
+    SCHEDULING_INTERVAL_MS = 10 * 60 * 1000  # 10 minutes
+
+    def __init__(self, hs: "HomeServer"):
+        self.store = hs.get_datastores().main
+        self.clock = hs.get_clock()
+        self._is_master = hs.config.worker.worker_app is None
+        self.running_tasks: Set[str] = set()
+        self.actions: Dict[
+            str, Callable[[ScheduledTask], Awaitable[Optional[ScheduledTask]]]
+        ] = {}
+
+        if self._is_master:
+            self.clock.looping_call(
+                run_as_background_process,
+                TaskSchedulerHandler.SCHEDULING_INTERVAL_MS,
+                "scheduled_tasks_loop",
+                self._scheduled_tasks_loop,
+            )
+
+    def bind_action(
+        self,
+        fct: Callable[[ScheduledTask], Awaitable[Optional[ScheduledTask]]],
+        action_name: str,
+    ) -> None:
+        self.actions[action_name] = fct
+
+    async def schedule_task(
+        self,
+        action: str,
+        *,
+        resource_id: Optional[str] = None,
+        timestamp: Optional[int] = None,
+        params: Optional[JsonMapping] = None,
+    ) -> str:
+        if action not in self.actions:
+            # No function has been bound for this action via `bind_action`.
+            raise SynapseError(400, f"Unknown task action: {action}")
+        task_id = random_string(16)
+        state = TaskState.SCHEDULED
+        if timestamp is None or timestamp < self.clock.time_msec():
+            state = TaskState.RUNNING
+            timestamp = self.clock.time_msec()
+
+        task = ScheduledTask(
+            task_id,
+            action,
+            state,
+            resource_id,
+            timestamp,
+            params,
+            None,
+        )
+        await self.store.upsert_scheduled_task(task)
+        return task_id
+
+    async def update_task_state(
+        self,
+        task: ScheduledTask,
+        # error: Optional[str],
+    ) -> None:
+        await self.store.upsert_scheduled_task(task)
+
+    async def get_task(self, id: str) -> Optional[ScheduledTask]:
+        return await self.store.get_scheduled_task(id)
+
+    async def get_tasks(
+        self, action: str, resource_id: Optional[str]
+    ) -> List[ScheduledTask]:
+        return await self.store.get_scheduled_tasks(action, resource_id)
+
+    async def _scheduled_tasks_loop(self) -> None:
+        for task in await self.store.get_scheduled_tasks():
+            if task.id not in self.running_tasks:
+                state = task.state
+                if (
+                    state == TaskState.SCHEDULED
+                    and task.timestamp is not None
+                    and task.timestamp < self.clock.time_msec()
+                ):
+                    state = TaskState.RUNNING
+
+                if state == TaskState.RUNNING:
+                    await self.store.upsert_scheduled_task(task)
+                    self._run_task(task)
+
+    def _run_task(self, task: ScheduledTask) -> None:
+        if task.action in self.actions:
+            fct = self.actions[task.action]
+
+            async def wrapper() -> None:
+                updated_task = await fct(task)
+                if updated_task is None:
+                    updated_task = attr.evolve(task, state=TaskState.COMPLETE)
+                await self.update_task_state(updated_task)
+
+            run_as_background_process(task.action, wrapper)
+            self.running_tasks.add(task.id)
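
As a point of reference for reviewers, here is a minimal sketch of how a handler could use the scheduler added above. The `CleanupHandler` class, the `clean_old_data` action name and its parameters are hypothetical, and the sketch assumes the `get_task_scheduler_handler` accessor that this diff adds to `HomeServer` further down.

# Illustrative sketch only -- not part of this change.
from typing import TYPE_CHECKING, Optional

import attr

from synapse.types import ScheduledTask, TaskState

if TYPE_CHECKING:
    from synapse.server import HomeServer


class CleanupHandler:
    def __init__(self, hs: "HomeServer"):
        self._clock = hs.get_clock()
        self._task_scheduler = hs.get_task_scheduler_handler()
        # Register the function to run when a "clean_old_data" task is due.
        self._task_scheduler.bind_action(self._clean_old_data, "clean_old_data")

    async def schedule_cleanup(self, room_id: str) -> str:
        # Schedule the task to run in one hour; the returned id can be used
        # with `get_task` to poll its state later.
        return await self._task_scheduler.schedule_task(
            "clean_old_data",
            resource_id=room_id,
            timestamp=self._clock.time_msec() + 60 * 60 * 1000,
            params={"room_id": room_id},
        )

    async def _clean_old_data(self, task: ScheduledTask) -> Optional[ScheduledTask]:
        # ... do the actual work here ...
        # Returning an evolved task sets the final state/result explicitly;
        # returning None lets the scheduler mark the task COMPLETE.
        return attr.evolve(task, state=TaskState.COMPLETE, result={"cleaned": True})
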
diff --git a/synapse/server.py b/synapse/server.py
index b72b76a38b35..c0e1277f6ae4 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -105,6 +105,7 @@
 from synapse.handlers.sso import SsoHandler
 from synapse.handlers.stats import StatsHandler
 from synapse.handlers.sync import SyncHandler
+from synapse.handlers.task_scheduler import TaskSchedulerHandler
 from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
 from synapse.handlers.user_directory import UserDirectoryHandler
 from synapse.http.client import (
@@ -242,6 +243,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         "profile",
         "room_forgetter",
         "stats",
+        "task_scheduler",
     ]

     # This is overridden in derived application classes
@@ -912,3 +914,7 @@ def get_request_ratelimiter(self) -> RequestRatelimiter:
     def get_common_usage_metrics_manager(self) -> CommonUsageMetricsManager:
         """Usage metrics shared between phone home stats and the prometheus exporter."""
         return CommonUsageMetricsManager(self)
+
+    @cache_in_self
+    def get_task_scheduler_handler(self) -> TaskSchedulerHandler:
+        return TaskSchedulerHandler(self)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 80c0304b1917..3ce82d5625f3 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -70,6 +70,7 @@
 from .stats import StatsStore
 from .stream import StreamWorkerStore
 from .tags import TagsStore
+from .task_scheduler import TaskSchedulerWorkerStore
 from .transactions import TransactionWorkerStore
 from .ui_auth import UIAuthStore
 from .user_directory import UserDirectoryStore
@@ -127,6 +128,7 @@ class DataStore(
     CacheInvalidationWorkerStore,
     LockStore,
     SessionStore,
+    TaskSchedulerWorkerStore,
 ):
     def __init__(
         self,
diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py
new file mode 100644
index 000000000000..0d279948e656
--- /dev/null
+++ b/synapse/storage/databases/main/task_scheduler.py
@@ -0,0 +1,101 @@
+import json
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.types import ScheduledTask, TaskState
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+class TaskSchedulerWorkerStore(SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+    @staticmethod
+    def _convert_row_to_task(row: Dict[str, Any]) -> ScheduledTask:
+        row["state"] = TaskState(row["state"])
+        if row["params"] is not None:
+            row["params"] = json.loads(row["params"])
+        if row["result"] is not None:
+            row["result"] = json.loads(row["result"])
+        return ScheduledTask(**row)
+
+    async def get_scheduled_tasks(
+        self, action: Optional[str] = None, resource_id: Optional[str] = None
+    ) -> List[ScheduledTask]:
+        keyvalues = {}
+        if action:
+            keyvalues["action"] = action
+        if resource_id:
+            keyvalues["resource_id"] = resource_id
+
+        rows = await self.db_pool.simple_select_list(
+            table="scheduled_tasks",
+            keyvalues=keyvalues,
+            retcols=(
+                "id",
+                "action",
+                "state",
+                "timestamp",
+                "resource_id",
+                "params",
+                "result",
+                # "error",
+            ),
+            desc="get_scheduled_tasks",
+        )
+
+        return [TaskSchedulerWorkerStore._convert_row_to_task(row) for row in rows]
+
+    async def upsert_scheduled_task(self, task: ScheduledTask) -> None:
+        await self.db_pool.simple_upsert(
+            "scheduled_tasks",
+            {"id": task.id},
+            {
+                "action": task.action,
+                "state": task.state,
+                "resource_id": task.resource_id,
+                "timestamp": task.timestamp,
+                "params": None if task.params is None else json.dumps(task.params),
+                "result": None if task.result is None else json.dumps(task.result),
+                # "error": task.error,
+            },
+            desc="upsert_scheduled_task",
+        )
+
+    async def get_scheduled_task(self, id: str) -> Optional[ScheduledTask]:
+        row = await self.db_pool.simple_select_one(
+            table="scheduled_tasks",
+            keyvalues={"id": id},
+            retcols=(
+                "id",
+                "action",
+                "state",
+                "resource_id",
+                "timestamp",
+                "params",
+                "result",
+                # "error",
+            ),
+            allow_none=True,
+            desc="get_scheduled_task",
+        )
+
+        return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
+
+    async def delete_scheduled_task(self, id: str) -> bool:
+        return (
+            await self.db_pool.simple_delete(
+                "scheduled_tasks",
+                keyvalues={"id": id},
+                desc="delete_scheduled_task",
+            )
+            > 0
+        )
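
For illustration, the store methods above compose as follows; the `prune_finished_tasks` helper is hypothetical and not part of this change, and it relies on the `TaskState` values introduced in `synapse/types` further down in this diff.

# Illustrative sketch only -- not part of this change.
from typing import List

from synapse.storage.databases.main.task_scheduler import TaskSchedulerWorkerStore
from synapse.types import ScheduledTask, TaskState


async def prune_finished_tasks(
    store: TaskSchedulerWorkerStore, resource_id: str
) -> int:
    """Delete every finished task attached to the given resource."""
    tasks: List[ScheduledTask] = await store.get_scheduled_tasks(
        resource_id=resource_id
    )
    deleted = 0
    for task in tasks:
        if task.state in (TaskState.COMPLETE, TaskState.FAILED, TaskState.ABORTED):
            if await store.delete_scheduled_task(task.id):
                deleted += 1
    return deleted
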
diff --git a/synapse/storage/schema/main/delta/78/05_scheduled_tasks.sql b/synapse/storage/schema/main/delta/78/05_scheduled_tasks.sql
new file mode 100644
index 000000000000..92760e7c986f
--- /dev/null
+++ b/synapse/storage/schema/main/delta/78/05_scheduled_tasks.sql
@@ -0,0 +1,26 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- cf ScheduledTask docstring for the meaning of the fields.
+CREATE TABLE IF NOT EXISTS scheduled_tasks(
+    id text PRIMARY KEY,
+    action text NOT NULL,
+    state text NOT NULL,
+    resource_id text,
+    timestamp bigint,
+    params text,
+    result text
+    -- error text
+);
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 095be070e0c5..561870f323c0 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -15,6 +15,7 @@
 import abc
 import re
 import string
+from enum import Enum
 from typing import (
     TYPE_CHECKING,
     AbstractSet,
@@ -979,3 +980,31 @@ class UserProfile(TypedDict):
 class RetentionPolicy:
     min_lifetime: Optional[int] = None
     max_lifetime: Optional[int] = None
+
+
+class TaskState(str, Enum):
+    SCHEDULED = "scheduled"
+    RUNNING = "running"
+    COMPLETE = "complete"
+    FAILED = "failed"
+    ABORTED = "aborted"
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class ScheduledTask:
+    """Description of a task scheduled with the `TaskSchedulerHandler`.
+
+    `id` is a unique random identifier, `action` is the name the task function
+    was bound to with `bind_action`, `state` tracks the task lifecycle,
+    `resource_id` optionally links the task to a resource, `timestamp` is when
+    the task should run (in ms since the epoch), and `params`/`result` hold the
+    task input and output as JSON.
+    """
+
+    id: str
+    action: str
+    state: TaskState
+    resource_id: Optional[str]
+    timestamp: Optional[int]
+    params: Optional[JsonMapping]
+    result: Optional[JsonDict]
+    # error: Optional[str]
diff --git a/tests/handlers/test_task_scheduler.py b/tests/handlers/test_task_scheduler.py
new file mode 100644
index 000000000000..de7b2fd14523
--- /dev/null
+++ b/tests/handlers/test_task_scheduler.py
@@ -0,0 +1,47 @@
+from typing import Optional
+
+import attr
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.types import ScheduledTask, TaskState
+from synapse.util import Clock
+
+from tests import unittest
+
+
+class TestTaskScheduler(unittest.HomeserverTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.handler = hs.get_task_scheduler_handler()
+        self.handler.bind_action(self._test_task, "test_action")
+
+    async def _test_task(self, task: ScheduledTask) -> Optional[ScheduledTask]:
+        if task.params:
+            val = task.params.get("val")
+            task = attr.evolve(task, state=TaskState.COMPLETE, result={"val": val})
+            return task
+        return None
+
+    def test_schedule_task(self) -> None:
+        timestamp = self.clock.time_msec() + 5 * 60 * 1000
+        task_id = self.get_success(
+            self.handler.schedule_task(
+                "test_action",
+                timestamp=timestamp,
+                params={"val": 1},
+            )
+        )
+
+        running_task = self.get_success(self.handler.get_task(task_id))
+        assert running_task is not None
+        self.assertEqual(running_task.state, TaskState.SCHEDULED)
+        self.assertIsNone(running_task.result)
+
+        self.reactor.advance(20 * 60)
+
+        running_task = self.get_success(self.handler.get_task(task_id))
+        assert running_task is not None
+        self.assertEqual(running_task.state, TaskState.COMPLETE)
+        assert running_task.result is not None
+        self.assertEqual(running_task.result.get("val"), 1)
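
Finally, a hedged sketch of how an action can drive the terminal state that `_run_task` persists; the `_purge_history` name and its failure condition are invented for the example.

# Illustrative sketch only -- not part of this change.
from typing import Optional

import attr

from synapse.types import ScheduledTask, TaskState


async def _purge_history(task: ScheduledTask) -> Optional[ScheduledTask]:
    if task.params is None or "room_id" not in task.params:
        # Record a FAILED state (and, once the commented-out `error` column
        # lands, presumably an error message as well).
        return attr.evolve(task, state=TaskState.FAILED)
    # ... do the work ...
    # Returning None lets `_run_task` mark the task COMPLETE automatically.
    return None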