fix: http workers pass their own id #200

Merged
28 changes: 22 additions & 6 deletions saq/queue/base.py
@@ -18,7 +18,7 @@
Status,
get_default_job_key,
)
from saq.utils import now, uuid1
from saq.utils import now

if t.TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterable, Sequence
@@ -59,7 +59,6 @@ def __init__(
load: LoadType | None,
) -> None:
self.name = name
self.uuid: str = uuid1()
self.started: int = now()
self.complete = 0
self.failed = 0
@@ -132,7 +131,16 @@ async def finish_abort(self, job: Job) -> None:
await job.finish(Status.ABORTED, error=job.error)

@abstractmethod
async def write_stats(self, stats: QueueStats, ttl: int) -> None:
async def write_stats(self, worker_id: str, stats: QueueStats, ttl: int) -> None:
"""
Returns & updates stats on the queue.

Args:
worker_id: The worker id. Passed in explicitly, rather than read from the queue instance, so that stats
are attributed to the worker instead of the shared queue instance inside the proxy server.
stats: The stats to write.
ttl: The time-to-live in seconds.
"""
pass

@abstractmethod
@@ -186,16 +194,24 @@ def deserialize(self, payload: dict | str | bytes | None) -> Job | None:
raise ValueError(f"Job {job_dict} fetched by wrong queue: {self.name}")
return Job(**job_dict, queue=self)

async def stats(self, ttl: int = 60) -> QueueStats:
async def stats(self, worker_id: str, ttl: int = 60) -> QueueStats:
"""
Method to be used by workers to update stats.

Args:
worker_id: The worker id.
ttl: Time stats are valid for in seconds.

Returns: The stats.
"""
stats: QueueStats = {
"complete": self.complete,
"failed": self.failed,
"retried": self.retried,
"aborted": self.aborted,
"uptime": now() - self.started,
}

await self.write_stats(stats, ttl)
await self.write_stats(worker_id, stats, ttl)
return stats

def register_before_enqueue(self, callback: BeforeEnqueueType) -> None:
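For context, a minimal sketch of what the new `write_stats` contract asks of a backend; `MemoryStatsBackend` is illustrative only, not a class in saq. The point is that storage is keyed by the `worker_id` handed down from `Queue.stats`, not by anything owned by the queue instance:

```python
# Illustrative only: an in-memory backend honoring the new write_stats
# contract. Names here are hypothetical, not part of saq.
import time
import typing as t


class MemoryStatsBackend:
    def __init__(self) -> None:
        # worker_id -> (stats, expiry timestamp)
        self._stats: dict[str, tuple[dict[str, t.Any], float]] = {}

    async def write_stats(self, worker_id: str, stats: dict[str, t.Any], ttl: int) -> None:
        # Keyed by the worker's id, so several workers sharing one queue
        # (or one HTTP proxy) no longer overwrite each other's entry.
        self._stats[worker_id] = (stats, time.monotonic() + ttl)

    def sweep(self) -> None:
        # Drop expired entries, mirroring what the real backends do via TTLs.
        cutoff = time.monotonic()
        self._stats = {k: v for k, v in self._stats.items() if v[1] > cutoff}
```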
8 changes: 5 additions & 3 deletions saq/queue/http.py
@@ -103,7 +103,9 @@ async def process(self, body: str) -> str | None:
)
)
if kind == "write_stats":
await self.queue.write_stats(req["stats"], ttl=req["ttl"])
await self.queue.write_stats(
worker_id=req["worker_id"], stats=req["stats"], ttl=req["ttl"]
)
return None
raise ValueError(f"Invalid request {body}")

@@ -206,8 +208,8 @@ async def finish_abort(self, job: Job) -> None:
async def dequeue(self, timeout: float = 0) -> Job | None:
return self.deserialize(await self._send("dequeue", timeout=timeout))

async def write_stats(self, stats: QueueStats, ttl: int) -> None:
await self._send("write_stats", stats=stats, ttl=ttl)
async def write_stats(self, worker_id: str, stats: QueueStats, ttl: int) -> None:
await self._send("write_stats", worker_id=worker_id, stats=stats, ttl=ttl)

async def info(self, jobs: bool = False, offset: int = 0, limit: int = 10) -> QueueInfo:
return json.loads(await self._send("info", jobs=jobs, offset=offset, limit=limit))
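The wire shape follows from `process` and `_send` above; a sketch of the proxied request body (field values hypothetical, serialization details inferred rather than quoted from saq):

```python
# Inferred body of a "write_stats" request after this change; process() above
# reads req["worker_id"], req["stats"], and req["ttl"] from it.
request_body = {
    "kind": "write_stats",
    "worker_id": "9a1c2e50-...",  # hypothetical worker id (uuid1 by default)
    "stats": {
        "complete": 10,
        "failed": 0,
        "retried": 0,
        "aborted": 0,
        "uptime": 123456,  # now() - started, as computed in Queue.stats
    },
    "ttl": 60,
}
```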
4 changes: 2 additions & 2 deletions saq/queue/postgres.py
@@ -670,7 +670,7 @@ async def _enqueue(self, job: Job) -> Job | None:
logger.info("Enqueuing %s", job.info(logger.isEnabledFor(logging.DEBUG)))
return job

async def write_stats(self, stats: QueueStats, ttl: int) -> None:
async def write_stats(self, worker_id: str, stats: QueueStats, ttl: int) -> None:
async with self.pool.connection() as conn:
await conn.execute(
SQL(
@@ -684,7 +684,7 @@ async def write_stats(self, stats: QueueStats, ttl: int) -> None:
)
).format(stats_table=self.stats_table),
{
"worker_id": self.uuid,
"worker_id": worker_id,
"stats": json.dumps(stats),
"ttl": ttl,
},
10 changes: 2 additions & 8 deletions saq/queue/redis.py
@@ -389,16 +389,10 @@ async def finish_abort(self, job: Job) -> None:
await self.redis.delete(job.abort_id)
await super().finish_abort(job)

async def write_stats(self, stats: QueueStats, ttl: int) -> None:
"""
Returns & updates stats on the queue

Args:
ttl: Time-to-live of stats saved in Redis
"""
async def write_stats(self, worker_id: str, stats: QueueStats, ttl: int) -> None:
current = now()
async with self.redis.pipeline(transaction=True) as pipe:
key = self.namespace(f"stats:{self.uuid}")
key = self.namespace(f"stats:{worker_id}")
await (
pipe.setex(key, ttl, json.dumps(stats))
.zremrangebyscore(self._stats, 0, current)
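Keying before and after, sketched below; the namespace prefix is an assumption about `Queue.namespace`, and the real pipeline also performs the sorted-set maintenance shown above:

```python
# Sketch of the new Redis key layout (namespace format assumed, not quoted
# from saq). Previously the key ended in the queue instance's uuid, so all
# workers behind one HTTP proxy collapsed into a single stats entry.
import json

from redis.asyncio import Redis


async def write_worker_stats(
    redis: Redis, queue_name: str, worker_id: str, stats: dict, ttl: int
) -> None:
    key = f"saq:{queue_name}:stats:{worker_id}"  # one TTL-bound entry per worker
    await redis.setex(key, ttl, json.dumps(stats))
```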
2 changes: 1 addition & 1 deletion saq/types.py
@@ -65,7 +65,7 @@ class QueueInfo(t.TypedDict):

class QueueStats(t.TypedDict):
"""
Queue Stats
Queue stats; also used to report per-worker stats
"""

complete: int
14 changes: 10 additions & 4 deletions saq/worker.py
@@ -19,7 +19,7 @@

from saq.job import Status
from saq.queue import Queue
from saq.utils import cancel_tasks, millis, now, now_seconds
from saq.utils import cancel_tasks, millis, now, now_seconds, uuid1

if t.TYPE_CHECKING:
from asyncio import Task
@@ -36,6 +36,7 @@
ReceivesContext,
SettingsDict,
TimersDict,
QueueStats,
)


@@ -47,6 +48,7 @@ class Worker:
Worker is used to process and monitor jobs.

Args:
id: optional override for the worker id; a uuid1 is generated if not provided
queue: instance of saq.queue.Queue
functions: list of async functions
concurrency: number of jobs to process concurrently
@@ -72,6 +74,7 @@ def __init__(
queue: Queue,
functions: Collection[Function | tuple[str, Function]],
*,
id: str | None = None,
concurrency: int = 10,
cron_jobs: Collection[CronJob] | None = None,
startup: ReceivesContext | Collection[ReceivesContext] | None = None,
@@ -115,6 +118,7 @@ def __init__(
self.burst_jobs_processed = 0
self.burst_jobs_processed_lock = threading.Lock()
self.burst_condition_met = False
self.id = id if id is not None else uuid1()

if self.burst:
if self.dequeue_timeout <= 0:
@@ -151,6 +155,7 @@ async def start(self) -> None:
"""Start processing jobs and upkeep tasks."""
logger.info("Worker starting: %s", repr(self.queue))
logger.debug("Registered functions:\n%s", "\n".join(f" {key}" for key in self.functions))
await self.stats()

try:
self.event = asyncio.Event()
@@ -213,6 +218,9 @@ async def schedule(self, lock: int = 1) -> None:
if scheduled:
logger.info("Scheduled %s", scheduled)

async def stats(self, ttl: int = 60) -> QueueStats:
return await self.queue.stats(self.id, ttl)

async def upkeep(self) -> list[Task[None]]:
"""Start various upkeep tasks async."""

@@ -233,9 +241,7 @@ async def poll(
asyncio.create_task(poll(self.abort, self.timers["abort"])),
asyncio.create_task(poll(self.schedule, self.timers["schedule"])),
asyncio.create_task(poll(self.queue.sweep, self.timers["sweep"])),
asyncio.create_task(
poll(self.queue.stats, self.timers["stats"], self.timers["stats"] + 1)
),
asyncio.create_task(poll(self.stats, self.timers["stats"], self.timers["stats"] + 1)),
]

async def abort(self, abort_threshold: float) -> None:
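Taken together, the behavior this PR enables looks roughly like the sketch below (assumes a reachable local Redis, that an empty functions list is acceptable for a demo, and the connect/disconnect lifecycle used in the test that follows):

```python
import asyncio

from saq import Queue, Worker


async def main() -> None:
    queue = Queue.from_url("redis://localhost")
    await queue.connect()

    # Two workers on one queue now report stats under their own ids;
    # omitting id falls back to a generated uuid1.
    w1 = Worker(queue=queue, functions=[], id="worker-1")
    w2 = Worker(queue=queue, functions=[], id="worker-2")
    await w1.stats()
    await w2.stats()

    info = await queue.info()
    assert {"worker-1", "worker-2"} <= info["workers"].keys()

    await queue.disconnect()


asyncio.run(main())
```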
91 changes: 91 additions & 0 deletions tests/test_http_stats.py
@@ -0,0 +1,91 @@
"""Validate that the worker id in the context of the http proxy is the id of the worker rather than the queue."""

import asyncio
import threading
import unittest
from http.server import BaseHTTPRequestHandler, HTTPServer

from saq import Queue, Worker
from saq.queue.http import HttpProxy
from saq.types import Context
from tests.helpers import create_postgres_queue, setup_postgres, teardown_postgres


async def echo(_ctx: Context, *, a: int) -> int:
return a


class ProxyRequestHandler(BaseHTTPRequestHandler):
def __init__(self, *args, proxy=None, **kwargs):
self.proxy = proxy
super().__init__(*args, **kwargs)

def do_POST(self):
length = int(self.headers["Content-Length"])
body = self.rfile.read(length).decode("utf-8")
response = asyncio.run(self.proxy.process(body))
if response:
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
self.wfile.write(response.encode("utf-8"))
else:
self.send_response(200)
self.end_headers()


class TestQueue(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self) -> None:
await setup_postgres()

async def asyncTearDown(self) -> None:
await teardown_postgres()

async def test_http_proxy_with_two_workers(self) -> None:
queue = await create_postgres_queue()
proxy = HttpProxy(queue=queue)

server = HTTPServer(
("localhost", 8080),
lambda *args, **kwargs: ProxyRequestHandler(*args, proxy=proxy, **kwargs),
)
server_thread = threading.Thread(target=server.serve_forever)
server_thread.daemon = True
server_thread.start()

queue1 = Queue.from_url("http://localhost:8080/")
await queue1.connect()
queue2 = Queue.from_url("http://localhost:8080/")
await queue2.connect()

worker = Worker(
queue=queue1,
functions=[echo],
)
await worker.stats()
worker2 = Worker(
queue=queue2,
functions=[echo],
)
await worker2.stats()
local_worker = Worker(
queue=queue,
functions=[echo],
)
await local_worker.stats()

root_info = await queue.info()
info1 = await queue1.info()
info2 = await queue2.info()

self.assertEqual(root_info["workers"], info1["workers"])
self.assertEqual(info1["workers"], info2["workers"])
self.assertEqual(info1["workers"].keys(), {worker.id, worker2.id, local_worker.id})
self.assertEqual(info1["workers"].keys(), info2["workers"].keys())

await queue1.disconnect()
await queue2.disconnect()
await queue.disconnect()

server.shutdown()
server_thread.join()
17 changes: 9 additions & 8 deletions tests/test_queue.py
@@ -193,14 +193,15 @@ async def test_retry_delay(self) -> None:
self.assertEqual(job.status, Status.QUEUED)

async def test_stats(self) -> None:
worker = Worker(self.queue, functions=functions)
for _ in range(10):
await self.enqueue("test")
job = await self.dequeue()
await job.retry(None)
await job.finish(Status.ABORTED)
await job.finish(Status.FAILED)
await job.finish(Status.COMPLETE)
stats = await self.queue.stats()
stats = await worker.stats()
self.assertEqual(stats["complete"], 10)
self.assertEqual(stats["failed"], 10)
self.assertEqual(stats["retried"], 10)
@@ -221,11 +222,10 @@ async def test_info(self) -> None:
await self.enqueue("echo", a=1)
await queue2.enqueue("echo", a=1)
await worker.process()
await self.queue.stats()
await queue2.stats()
await worker.stats()

info = await self.queue.info(jobs=True)
self.assertEqual(set(info["workers"].keys()), {self.queue.uuid, queue2.uuid})
self.assertEqual(set(info["workers"].keys()), {worker.id})
self.assertEqual(info["active"], 0)
self.assertEqual(info["queued"], 1)
self.assertEqual(len(info["jobs"]), 1)
@@ -580,8 +580,9 @@ async def test_sweep_jobs(self) -> None:
self.assertEqual(job2.status, Status.COMPLETE)

async def test_sweep_stats(self) -> None:
worker = Worker(self.queue, functions=functions)
# Stats are deleted
await self.queue.stats(ttl=1)
await worker.stats(ttl=1)
await asyncio.sleep(1.5)
await self.queue.sweep()
async with self.queue.pool.connection() as conn, conn.cursor() as cursor:
@@ -593,12 +594,12 @@
WHERE worker_id = %s
"""
).format(self.queue.stats_table),
(self.queue.uuid,),
(worker.id,),
)
self.assertIsNone(await cursor.fetchone())

# Stats are not deleted
await self.queue.stats(ttl=60)
await worker.stats(ttl=60)
await asyncio.sleep(1)
await self.queue.sweep()
async with self.queue.pool.connection() as conn, conn.cursor() as cursor:
@@ -610,7 +611,7 @@
WHERE worker_id = %s
"""
).format(self.queue.stats_table),
(self.queue.uuid,),
(worker.id,),
)
self.assertIsNotNone(await cursor.fetchone())
