diff --git a/cads_worker/entry_points.py b/cads_worker/entry_points.py
index b62a135..874a76e 100644
--- a/cads_worker/entry_points.py
+++ b/cads_worker/entry_points.py
@@ -28,10 +28,10 @@ class CleanerKwargs(TypedDict):
     delete_unknown_files: bool
     lock_validity_period: float
     use_database: bool
+    depth: int


 def _cache_cleaner() -> None:
-    cache_bucket = os.environ.get("CACHE_BUCKET", None)
     use_database = strtobool(os.environ.get("USE_DATABASE", "1"))
     cleaner_kwargs = CleanerKwargs(
         maxsize=int(os.environ.get("MAX_SIZE", 1_000_000_000)),
@@ -39,13 +39,20 @@ def _cache_cleaner() -> None:
         delete_unknown_files=not use_database,
         lock_validity_period=float(os.environ.get("LOCK_VALIDITY_PERIOD", 86400)),
         use_database=use_database,
+        depth=int(os.getenv("CACHE_DEPTH", 2)),
     )
-    LOGGER.info("Running cache cleaner", cache_bucket=cache_bucket, **cleaner_kwargs)
-    try:
-        cacholote.clean_cache_files(**cleaner_kwargs)
-    except Exception:
-        LOGGER.exception("cache_cleaner crashed")
-        raise
+    for cache_files_urlpath in utils.parse_data_volumes_config():
+        cacholote.config.set(cache_files_urlpath=cache_files_urlpath)
+        LOGGER.info(
+            "Running cache cleaner",
+            cache_files_urlpath=cache_files_urlpath,
+            **cleaner_kwargs,
+        )
+        try:
+            cacholote.clean_cache_files(**cleaner_kwargs)
+        except Exception:
+            LOGGER.exception("cache_cleaner crashed")
+            raise


 def _add_tzinfo(timestamp: datetime.datetime) -> datetime.datetime:
diff --git a/cads_worker/utils.py b/cads_worker/utils.py
index 7f57682..bbe6801 100644
--- a/cads_worker/utils.py
+++ b/cads_worker/utils.py
@@ -1,4 +1,8 @@
+import contextlib
 import os
+import pathlib
+import tempfile
+from collections.abc import Iterator


 def parse_data_volumes_config(path: str | None = None) -> list[str]:
@@ -7,3 +11,25 @@ def parse_data_volumes_config(path: str | None = None) -> list[str]:

     with open(path) as fp:
         return [os.path.expandvars(line.rstrip("\n")) for line in fp]
+
+
+@contextlib.contextmanager
+def enter_tmp_working_dir() -> Iterator[str]:
+    old_cwd = os.getcwd()
+    with tempfile.TemporaryDirectory() as tmpdir:
+        os.chdir(tmpdir)
+        try:
+            yield os.getcwd()
+        finally:
+            os.chdir(old_cwd)
+
+
+@contextlib.contextmanager
+def make_cache_tmp_path(base_dir: str) -> Iterator[pathlib.Path]:
+    with tempfile.TemporaryDirectory(dir=base_dir) as tmpdir:
+        cache_tmp_path = pathlib.Path(tmpdir)
+        cache_tmp_path.with_suffix(".lock").touch()
+        try:
+            yield cache_tmp_path
+        finally:
+            cache_tmp_path.with_suffix(".lock").unlink(missing_ok=True)
diff --git a/cads_worker/worker.py b/cads_worker/worker.py
index 7942d86..4720a71 100644
--- a/cads_worker/worker.py
+++ b/cads_worker/worker.py
@@ -3,7 +3,6 @@
 import os
 import random
 import socket
-import tempfile
 from typing import Any

 import cacholote
@@ -205,36 +204,41 @@ def submit_workflow(
         context.warn(f"CACHE_DETPH={depth} is not supported.")

     logger.info("Processing job", job_id=job_id)
+    collection_id = config.get("collection_id")
     cacholote.config.set(
         logger=LOGGER,
         cache_files_urlpath=cache_files_urlpath,
         sessionmaker=context.session_maker,
         context=context,
+        tag=collection_id,
     )
+    fs, dirname = cacholote.utils.get_cache_files_fs_dirname()
+
     adaptor_class = cads_adaptors.get_adaptor_class(entry_point, setup_code)
-    adaptor = adaptor_class(form=form, context=context, **config)
-    collection_id = config.get("collection_id")
-    cwd = os.getcwd()
-    with tempfile.TemporaryDirectory() as tmpdir:
-        os.chdir(tmpdir)
-        try:
-            request = {k: request[k] for k in sorted(request.keys())}
-            with cacholote.config.set(tag=collection_id):
-                result = cacholote.cacheable(
-                    adaptor.retrieve, collection_id=collection_id
-                )(request=request)
-        except Exception as err:
-            logger.exception(job_id=job_id, event_type="EXCEPTION")
-            context.add_user_visible_error(
-                f"The job failed with: {err.__class__.__name__}"
-            )
-            context.error(f"{err.__class__.__name__}: {str(err)}")
-            raise
-        finally:
-            os.chdir(cwd)
-    fs, _ = cacholote.utils.get_cache_files_fs_dirname()
-    if (local_path := result.result["args"][0]["file:local_path"]).startswith("s3://"):
-        fs.chmod(local_path, acl="public-read")
+    try:
+        with utils.enter_tmp_working_dir() as working_dir:
+            base_dir = dirname if "file" in fs.protocol else working_dir
+            with utils.make_cache_tmp_path(base_dir) as cache_tmp_path:
+                adaptor = adaptor_class(
+                    form=form,
+                    context=context,
+                    cache_tmp_path=cache_tmp_path,
+                    **config,
+                )
+                request = {k: request[k] for k in sorted(request.keys())}
+                cached_retrieve = cacholote.cacheable(
+                    adaptor.retrieve,
+                    collection_id=collection_id,
+                )
+                result = cached_retrieve(request=request)
+    except Exception as err:
+        logger.exception(job_id=job_id, event_type="EXCEPTION")
+        context.add_user_visible_error(f"The job failed with: {err.__class__.__name__}")
+        context.error(f"{err.__class__.__name__}: {str(err)}")
+        raise
+
+    if "s3" in fs.protocol:
+        fs.chmod(result.result["args"][0]["file:local_path"], acl="public-read")
     with context.session_maker() as session:
         request = cads_broker.database.set_request_cache_id(
             request_uid=job_id,
diff --git a/tests/test_30_utils.py b/tests/test_30_utils.py
index fa856c3..cee54e0 100644
--- a/tests/test_30_utils.py
+++ b/tests/test_30_utils.py
@@ -1,4 +1,6 @@
+import os
 import pathlib
+import tempfile

 import pytest

@@ -17,3 +19,20 @@ def test_utils_parse_data_volumes_config(
     monkeypatch.setenv("DATA_VOLUMES_CONFIG", str(data_volumes_config))

     assert utils.parse_data_volumes_config(None) == ["foo", "bar"]
+
+
+def test_utils_enter_tmp_working_dir() -> None:
+    with utils.enter_tmp_working_dir() as tmp_working_dir:
+        assert os.getcwd() == tmp_working_dir
+        assert os.path.dirname(tmp_working_dir) == os.path.realpath(
+            tempfile.gettempdir()
+        )
+    assert not os.path.exists(tmp_working_dir)
+
+
+def test_utils_make_cache_tmp_path(tmp_path: pathlib.Path) -> None:
+    with utils.make_cache_tmp_path(str(tmp_path)) as cache_tmp_path:
+        assert cache_tmp_path.parent == tmp_path
+        assert cache_tmp_path.with_suffix(".lock").exists()
+    assert not cache_tmp_path.exists()
+    assert not cache_tmp_path.with_suffix(".lock").exists()
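Usage note (illustration only, not part of the patch): a minimal sketch of how the two new context managers in cads_worker/utils.py compose, mirroring the flow added to submit_workflow(). The staged file name "result.grib" is a hypothetical placeholder.

    import os

    from cads_worker import utils

    with utils.enter_tmp_working_dir() as working_dir:
        # The process now runs inside a throwaway directory under tempfile.gettempdir().
        assert os.getcwd() == working_dir
        with utils.make_cache_tmp_path(working_dir) as cache_tmp_path:
            # cache_tmp_path is a fresh staging directory guarded by a sibling ".lock" file;
            # an adaptor can write its result files here.
            (cache_tmp_path / "result.grib").write_bytes(b"")
    # On exit the lock file, the staging directory and the temporary working
    # directory are all removed, and the original working directory is restored.

In submit_workflow() the staging directory is created next to the cacholote cache (base_dir = dirname) only when the cache filesystem is local ("file" in fs.protocol); otherwise it falls back to the temporary working directory.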