From d69bb5223f7a52a095789e6873fcf0554eb4dbdf Mon Sep 17 00:00:00 2001
From: Alessio Siniscalchi
Date: Wed, 8 Jan 2025 12:47:54 +0100
Subject: [PATCH 1/9] create multiple buckets at init

---
 cads_broker/entry_points.py   | 13 ++++++++-----
 cads_broker/object_storage.py | 10 ++++++++--
 tests/test_90_entry_points.py | 15 +++++++++++----
 3 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/cads_broker/entry_points.py b/cads_broker/entry_points.py
index e52aace6..e09dc7b1 100644
--- a/cads_broker/entry_points.py
+++ b/cads_broker/entry_points.py
@@ -6,10 +6,11 @@
 import random
 import uuid
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, List, Optional

 import sqlalchemy as sa
 import typer
+from cads_worker import utils
 from typing_extensions import Annotated

 from cads_broker import config, database, dispatcher, object_storage
@@ -245,10 +246,12 @@ def init_db(connection_string: Optional[str] = None, force: bool = False) -> Non
         "aws_access_key_id": os.environ["STORAGE_ADMIN"],
         "aws_secret_access_key": os.environ["STORAGE_PASSWORD"],
     }
-    object_storage.create_download_bucket(
-        os.environ.get("CACHE_BUCKET", "cache"), object_storage_url, **storage_kws
-    )
-    print("successfully created the cache area in the object storage.")
+    download_buckets: List[str] = utils.parse_data_volumes_config()
+    for download_bucket in download_buckets:
+        object_storage.create_download_bucket(
+            download_bucket, object_storage_url, **storage_kws
+        )
+    print("successfully created the cache areas in the object storage.")


 @app.command()
diff --git a/cads_broker/object_storage.py b/cads_broker/object_storage.py
index 74a601d5..830573b0 100644
--- a/cads_broker/object_storage.py
+++ b/cads_broker/object_storage.py
@@ -1,5 +1,6 @@
 """utility module to interface to the object storage."""

+import urllib.parse
 from typing import Any

 import boto3  # type: ignore
@@ -43,13 +44,18 @@ def create_download_bucket(

     Parameters
     ----------
-    bucket_name: name of the bucket
+    bucket_name: name of the bucket (such as 's3://mybucketname' or just 'mybucketname')
     object_storage_url: endpoint URL of the object storage
     client: client to use, default is boto3 (used for testing)
     storage_kws: dictionary of parameters used to pass to the storage client.
""" + bucket_url_obj = urllib.parse.urlparse(bucket_name) + scheme = "s3" + if bucket_url_obj.scheme: + scheme = bucket_url_obj.scheme + bucket_name = bucket_url_obj.netloc if not client: - client = boto3.client("s3", endpoint_url=object_storage_url, **storage_kws) + client = boto3.client(scheme, endpoint_url=object_storage_url, **storage_kws) if not is_bucket_existing(client, bucket_name): logger.info(f"creation of bucket {bucket_name}") client.create_bucket(Bucket=bucket_name) diff --git a/tests/test_90_entry_points.py b/tests/test_90_entry_points.py index b232a488..b7345403 100644 --- a/tests/test_90_entry_points.py +++ b/tests/test_90_entry_points.py @@ -1,4 +1,6 @@ import datetime +import os +import unittest.mock import uuid from typing import Any @@ -57,8 +59,11 @@ def mock_config( return adaptor_properties -def test_init_db(postgresql: Connection[str], mocker) -> None: +def test_init_db(postgresql: Connection[str], tmpdir, mocker) -> None: patch_storage = mocker.patch.object(object_storage, "create_download_bucket") + data_volumes_config_path = os.path.join(str(tmpdir), "data_volumes.config") + with open(data_volumes_config_path, "w") as fp: + fp.writelines(["s3://mybucket1\n", "s3://mybucket2\n"]) connection_string = ( f"postgresql://{postgresql.info.user}:" f"@{postgresql.info.host}:{postgresql.info.port}/{postgresql.info.dbname}" @@ -80,12 +85,14 @@ def test_init_db(postgresql: Connection[str], mocker) -> None: "OBJECT_STORAGE_URL": object_storage_url, "STORAGE_ADMIN": object_storage_kws["aws_access_key_id"], "STORAGE_PASSWORD": object_storage_kws["aws_secret_access_key"], + "DATA_VOLUMES_CONFIG": data_volumes_config_path, }, ) assert result.exit_code == 0 - patch_storage.assert_called_once_with( - "cache", object_storage_url, **object_storage_kws - ) + assert patch_storage.mock_calls == [ + unittest.mock.call("s3://mybucket1", object_storage_url, **object_storage_kws), + unittest.mock.call("s3://mybucket2", object_storage_url, **object_storage_kws), + ] assert set(conn.execute(query).scalars()) == set( database.BaseModel.metadata.tables ).union( From 874792ef463c3255389312665a93901726caf8de Mon Sep 17 00:00:00 2001 From: Alessio Siniscalchi Date: Wed, 8 Jan 2025 13:16:46 +0100 Subject: [PATCH 2/9] removed dependency of cads_worker from entry_points --- cads_broker/entry_points.py | 3 +-- cads_broker/object_storage.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/cads_broker/entry_points.py b/cads_broker/entry_points.py index e09dc7b1..4c6e9e77 100644 --- a/cads_broker/entry_points.py +++ b/cads_broker/entry_points.py @@ -10,7 +10,6 @@ import sqlalchemy as sa import typer -from cads_worker import utils from typing_extensions import Annotated from cads_broker import config, database, dispatcher, object_storage @@ -246,7 +245,7 @@ def init_db(connection_string: Optional[str] = None, force: bool = False) -> Non "aws_access_key_id": os.environ["STORAGE_ADMIN"], "aws_secret_access_key": os.environ["STORAGE_PASSWORD"], } - download_buckets: List[str] = utils.parse_data_volumes_config() + download_buckets: List[str] = object_storage.parse_data_volumes_config() for download_bucket in download_buckets: object_storage.create_download_bucket( download_bucket, object_storage_url, **storage_kws diff --git a/cads_broker/object_storage.py b/cads_broker/object_storage.py index 830573b0..36717955 100644 --- a/cads_broker/object_storage.py +++ b/cads_broker/object_storage.py @@ -1,5 +1,6 @@ """utility module to interface to the object storage.""" +import os.path 
 import urllib.parse
 from typing import Any

 import boto3  # type: ignore
@@ -10,6 +11,18 @@
 logger: structlog.stdlib.BoundLogger = structlog.get_logger(__name__)


+def parse_data_volumes_config(path: str | None = None) -> list[str]:
+    if path is None:
+        path = os.environ["DATA_VOLUMES_CONFIG"]
+
+    data_volumes = []
+    with open(path) as fp:
+        for line in fp:
+            if data_volume := os.path.expandvars(line.rstrip("\n")):
+                data_volumes.append(data_volume)
+    return data_volumes
+
+
 def is_bucket_existing(client: Any, bucket_name: str) -> bool | None:
     """Return True if the bucket exists."""
     try:

From 86eef03e6bbd1af3eeaa395033db824f9bce9024 Mon Sep 17 00:00:00 2001
From: Francesco Nazzaro
Date: Thu, 16 Jan 2025 14:01:07 +0100
Subject: [PATCH 3/9] kill job on workers

---
 cads_broker/database.py   |  1 +
 cads_broker/dispatcher.py | 13 +++++++++++++
 2 files changed, 14 insertions(+)

diff --git a/cads_broker/database.py b/cads_broker/database.py
index 277fcaf8..c1960af7 100644
--- a/cads_broker/database.py
+++ b/cads_broker/database.py
@@ -779,6 +779,7 @@ def logger_kwargs(request: SystemRequest) -> dict[str, str]:
             if event.event_type == "worker_name"
         ],
         "origin": request.origin,
+        "cache_id": request.cache_id,
         "portal": request.portal,
         "entry_point": request.entry_point,
         "request_metadata": request.request_metadata,
diff --git a/cads_broker/dispatcher.py b/cads_broker/dispatcher.py
index 1b1a461d..f12e9526 100644
--- a/cads_broker/dispatcher.py
+++ b/cads_broker/dispatcher.py
@@ -7,6 +7,7 @@
 import time
 import traceback
 from typing import Any
+import signal

 import attrs
 import cachetools
@@ -92,6 +93,17 @@ def get_tasks_on_scheduler(dask_scheduler: distributed.Scheduler) -> dict[str, A
     return client.run_on_scheduler(get_tasks_on_scheduler)


+def kill_job_on_worker(client: distributed.Client, request_uid: str) -> None:
+    worker_pid_event = client.get_events(request_uid)[0][1]
+    client.run(
+        os.kill,
+        worker_pid_event["pid"],
+        signal.SIGTERM,
+        workers=[worker_pid_event["worker"]],
+        nanny=True,
+    )
+
+
 def cancel_jobs_on_scheduler(client: distributed.Client, job_ids: list[str]) -> None:
     """Cancel jobs on the dask scheduler.

@@ -420,6 +432,7 @@ def sync_database(self, session: sa.orm.Session) -> None:
         for request in dismissed_requests:
             if future := self.futures.pop(request.request_uid, None):
                 future.cancel()
+                kill_job_on_worker(self.client, request.request_uid)
             else:
                 # if the request is not in the futures, it means that the request has been lost by the broker
                 # try to cancel the job directly on the scheduler

From af93d20a1ab7732bb611b8c7f7bef6fccf39f903 Mon Sep 17 00:00:00 2001
From: Francesco Nazzaro
Date: Thu, 16 Jan 2025 14:31:45 +0100
Subject: [PATCH 4/9] killed job

---
 cads_broker/dispatcher.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cads_broker/dispatcher.py b/cads_broker/dispatcher.py
index f12e9526..102d90c8 100644
--- a/cads_broker/dispatcher.py
+++ b/cads_broker/dispatcher.py
@@ -432,11 +432,11 @@ def sync_database(self, session: sa.orm.Session) -> None:
         for request in dismissed_requests:
             if future := self.futures.pop(request.request_uid, None):
                 future.cancel()
-                kill_job_on_worker(self.client, request.request_uid)
             else:
                 # if the request is not in the futures, it means that the request has been lost by the broker
                 # try to cancel the job directly on the scheduler
                 cancel_jobs_on_scheduler(self.client, job_ids=[request.request_uid])
+            kill_job_on_worker(self.client, request.request_uid)
             session = self.manage_dismissed_request(request, session)
         session.commit()

From 3b3b8f2d5e9984077c993d9ea405b1897ffaa1b3 Mon Sep 17 00:00:00 2001
From: Francesco Nazzaro
Date: Thu, 16 Jan 2025 14:54:19 +0100
Subject: [PATCH 5/9] qa

---
 cads_broker/dispatcher.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cads_broker/dispatcher.py b/cads_broker/dispatcher.py
index 102d90c8..7cb3be7b 100644
--- a/cads_broker/dispatcher.py
+++ b/cads_broker/dispatcher.py
@@ -3,11 +3,11 @@
 import io
 import os
 import pickle
+import signal
 import threading
 import time
 import traceback
 from typing import Any
-import signal

 import attrs
 import cachetools

From 2c2f8bfacdd8c0c2d4986cd7e46c50c596044a52 Mon Sep 17 00:00:00 2001
From: Francesco Nazzaro
Date: Thu, 16 Jan 2025 15:45:49 +0100
Subject: [PATCH 6/9] Enhance job termination process on workers to handle multiple processes and improve error logging

---
 cads_broker/dispatcher.py | 25 +++++++++++++++++--------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/cads_broker/dispatcher.py b/cads_broker/dispatcher.py
index 7cb3be7b..3e3eedba 100644
--- a/cads_broker/dispatcher.py
+++ b/cads_broker/dispatcher.py
@@ -94,14 +94,23 @@ def get_tasks_on_scheduler(dask_scheduler: distributed.Scheduler) -> dict[str, A


 def kill_job_on_worker(client: distributed.Client, request_uid: str) -> None:
-    worker_pid_event = client.get_events(request_uid)[0][1]
-    client.run(
-        os.kill,
-        worker_pid_event["pid"],
-        signal.SIGTERM,
-        workers=[worker_pid_event["worker"]],
-        nanny=True,
-    )
+    """Kill the job on the worker."""
+    # loop on all the processes related to the request_uid
+    for worker_pid_event in client.get_events(request_uid):
+        _, worker_pid_event = worker_pid_event
+        pid = worker_pid_event["pid"]
+        worker_ip = worker_pid_event["worker"]
+        try:
+            client.run(
+                os.kill,
+                pid,
+                signal.SIGTERM,
+                workers=[worker_ip],
+                nanny=True,
+            )
+            logger.info("killing worker", job_id=request_uid, pid=pid, worker_ip=worker_ip)
+        except (KeyError, NameError):
+            logger.warning("worker not found", job_id=request_uid, pid=pid, worker_ip=worker_ip)


 def cancel_jobs_on_scheduler(client: distributed.Client, job_ids: list[str]) -> None:
From 694fc1f9bca4e2a9ecf187b8ae433c10b03fd4ad Mon Sep 17 00:00:00 2001
From: Francesco Nazzaro
Date: Thu, 16 Jan 2025 15:54:01 +0100
Subject: [PATCH 7/9] Refactor logging in job termination process for improved readability

---
 cads_broker/dispatcher.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/cads_broker/dispatcher.py b/cads_broker/dispatcher.py
index 3e3eedba..f4a8b415 100644
--- a/cads_broker/dispatcher.py
+++ b/cads_broker/dispatcher.py
@@ -108,9 +108,13 @@ def kill_job_on_worker(client: distributed.Client, request_uid: str) -> None:
                 workers=[worker_ip],
                 nanny=True,
             )
-            logger.info("killing worker", job_id=request_uid, pid=pid, worker_ip=worker_ip)
+            logger.info(
+                "killing worker", job_id=request_uid, pid=pid, worker_ip=worker_ip
+            )
         except (KeyError, NameError):
-            logger.warning("worker not found", job_id=request_uid, pid=pid, worker_ip=worker_ip)
+            logger.warning(
+                "worker not found", job_id=request_uid, pid=pid, worker_ip=worker_ip
+            )


 def cancel_jobs_on_scheduler(client: distributed.Client, job_ids: list[str]) -> None:

From 279e691e726507cea5513c14c92620042368b265 Mon Sep 17 00:00:00 2001
From: Francesco Nazzaro
Date: Thu, 16 Jan 2025 16:53:01 +0100
Subject: [PATCH 9/9] improve log message

---
 cads_broker/dispatcher.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cads_broker/dispatcher.py b/cads_broker/dispatcher.py
index f4a8b415..82044686 100644
--- a/cads_broker/dispatcher.py
+++ b/cads_broker/dispatcher.py
@@ -109,7 +109,7 @@ def kill_job_on_worker(client: distributed.Client, request_uid: str) -> None:
                 nanny=True,
             )
             logger.info(
-                "killing worker", job_id=request_uid, pid=pid, worker_ip=worker_ip
+                "killed job on worker", job_id=request_uid, pid=pid, worker_ip=worker_ip
             )
         except (KeyError, NameError):
             logger.warning(
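
A note on patches 1-2: parse_data_volumes_config reads one bucket per line from the file pointed to by DATA_VOLUMES_CONFIG, expands environment variables with os.path.expandvars, and skips empty lines; init_db then creates one download bucket per entry. A minimal sketch of that behaviour follows; the file name, bucket names and the STORAGE_SUFFIX variable are made up for illustration, not taken from the patches:

    import os

    from cads_broker import object_storage

    os.environ["STORAGE_SUFFIX"] = "prod"  # hypothetical variable referenced in the config
    with open("data_volumes.config", "w") as fp:
        fp.write("s3://cache-$STORAGE_SUFFIX\n")  # expanded by os.path.expandvars
        fp.write("\n")                            # empty lines are skipped
        fp.write("s3://mybucket2\n")

    os.environ["DATA_VOLUMES_CONFIG"] = "data_volumes.config"
    print(object_storage.parse_data_volumes_config())
    # prints: ['s3://cache-prod', 's3://mybucket2']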
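
A note on patches 3-9: kill_job_on_worker assumes each job has published Dask events under its request UID, with payloads carrying the worker process "pid" and the "worker" address; client.get_events returns (timestamp, payload) tuples, which is why the loop unpacks each item before reading those keys. The emitting side is not part of this series, so the sketch below is a guess at what the worker does when a task starts: log_event and get_worker are standard distributed APIs, but the function name announce_pid and its call site are assumptions.

    import os

    import distributed

    def announce_pid(request_uid: str) -> None:
        # Must run inside a task: get_worker() resolves the current worker.
        worker = distributed.get_worker()
        # Publish under the request UID so that client.get_events(request_uid)
        # on the broker side finds a payload with the keys kill_job_on_worker reads.
        worker.log_event(request_uid, {"pid": os.getpid(), "worker": worker.address})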