diff --git a/ch_backup/backup/layout.py b/ch_backup/backup/layout.py index f15db8c1..030f1b92 100644 --- a/ch_backup/backup/layout.py +++ b/ch_backup/backup/layout.py @@ -4,7 +4,7 @@ import os from pathlib import Path -from typing import Callable, List, Optional, Sequence +from typing import Any, Callable, List, Optional, Sequence from urllib.parse import quote from nacl.exceptions import CryptoError @@ -248,7 +248,9 @@ def get_udf_create_statement( """ Download user defined function create statement. """ - remote_path = _udf_data_path(backup_meta.path, filename) + remote_path = self._get_escaped_if_exists( + _udf_data_path, backup_meta.path, filename + ) return self._storage_loader.download_data(remote_path, encryption=True) def get_local_nc_create_statement(self, nc_name: str) -> Optional[str]: @@ -436,8 +438,12 @@ def download_data_part( os.makedirs(fs_part_path, exist_ok=True) - remote_dir_path = _part_path( - part.link or backup_meta.path, part.database, part.table, part.name + remote_dir_path = self._get_escaped_if_exists( + _part_path, + part.link or backup_meta.path, + part.database, + part.table, + part.name, ) if part.tarball: @@ -474,8 +480,12 @@ def check_data_part(self, backup_path: str, part: PartMetadata) -> bool: Check availability of part data in storage. """ try: - remote_dir_path = _part_path( - part.link or backup_path, part.database, part.table, part.name + remote_dir_path = self._get_escaped_if_exists( + _part_path, + part.link or backup_path, + part.database, + part.table, + part.name, ) remote_files = self._storage_loader.list_dir(remote_dir_path) @@ -560,8 +570,12 @@ def delete_data_parts( deleting_files: List[str] = [] for part in parts: - part_path = _part_path( - part.link or backup_meta.path, part.database, part.table, part.name + part_path = self._get_escaped_if_exists( + _part_path, + part.link or backup_meta.path, + part.database, + part.table, + part.name, ) logging.debug("Deleting data part {}", part_path) if part.tarball: @@ -615,6 +629,17 @@ def _target_part_size(self, part: PartMetadata) -> int: tar_size, self._encryption_chunk_size, self._encryption_metadata_size ) + def _get_escaped_if_exists( + self, path_function: Callable, *args: Any, **kwargs: Any + ) -> str: + """ + Return escaped path if it exists. Otherwise return regular path. + """ + path = path_function(*args, escape_names=True, **kwargs) + if self._storage_loader.path_exists(path, is_dir=True): + return path + return path_function(*args, escape_names=False, **kwargs) + def _access_control_data_path(backup_path: str, file_name: str) -> str: """ @@ -623,10 +648,12 @@ def _access_control_data_path(backup_path: str, file_name: str) -> str: return os.path.join(backup_path, "access_control", file_name) -def _udf_data_path(backup_path: str, udf_file: str) -> str: +def _udf_data_path(backup_path: str, udf_file: str, escape_names: bool = True) -> str: """ Return S3 path to UDF data """ + if escape_names: + return os.path.join(backup_path, "udf", _quote(udf_file)) return os.path.join(backup_path, "udf", udf_file) @@ -653,10 +680,20 @@ def _named_collections_data_path(backup_path: str, nc_name: str) -> str: return os.path.join(backup_path, "named_collections", _quote(nc_name) + ".sql") -def _part_path(backup_path: str, db_name: str, table_name: str, part_name: str) -> str: +def _part_path( + backup_path: str, + db_name: str, + table_name: str, + part_name: str, + escape_names: bool = True, +) -> str: """ Return S3 path to data part. """ + if escape_names: + return os.path.join( + backup_path, "data", _quote(db_name), _quote(table_name), part_name + ) return os.path.join(backup_path, "data", db_name, table_name, part_name) diff --git a/ch_backup/storage/engine/base.py b/ch_backup/storage/engine/base.py index 5a765d39..b11b023a 100644 --- a/ch_backup/storage/engine/base.py +++ b/ch_backup/storage/engine/base.py @@ -55,7 +55,7 @@ def list_dir( pass @abstractmethod - def path_exists(self, remote_path: str) -> bool: + def path_exists(self, remote_path: str, is_dir: bool = False) -> bool: """ Check if remote path exists. """ diff --git a/ch_backup/storage/engine/s3/s3_engine.py b/ch_backup/storage/engine/s3/s3_engine.py index cc0fd6a1..a8d5871e 100644 --- a/ch_backup/storage/engine/s3/s3_engine.py +++ b/ch_backup/storage/engine/s3/s3_engine.py @@ -142,10 +142,12 @@ def list_dir( return contents - def path_exists(self, remote_path: str) -> bool: + def path_exists(self, remote_path: str, is_dir: bool = False) -> bool: """ Check if remote path exists. """ + if is_dir: + return self._directory_exists(remote_path) try: self._s3_client.head_object(Bucket=self._s3_bucket_name, Key=remote_path) return True @@ -155,6 +157,24 @@ def path_exists(self, remote_path: str) -> bool: return False raise ce + def _directory_exists(self, remote_path: str) -> bool: + """ + Check if remote directory exists. + """ + remote_path = remote_path.rstrip("/") + resp = self._s3_client.list_objects( + Bucket=self._s3_bucket_name, Prefix=remote_path, Delimiter="/" + ) + # CommonPrefixes contains all (if there are any) keys between + # Prefix and the next occurrence of the string specified by the delimiter. + # If there are more than 1000 keys which satisfy given Prefix this may not work, + # but there should be almost never more than one such key in our cases. + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/list_objects.html + for prefix in resp.get("CommonPrefixes", []): + if prefix["Prefix"].rstrip("/") == remote_path: + return True + return False + def create_multipart_upload(self, remote_path: str) -> str: return self._multipart_uploader.create_multipart_upload(remote_path) diff --git a/ch_backup/storage/loader.py b/ch_backup/storage/loader.py index af06f91e..b410e871 100644 --- a/ch_backup/storage/loader.py +++ b/ch_backup/storage/loader.py @@ -203,11 +203,11 @@ def list_dir( remote_path, recursive=recursive, absolute=absolute ) - def path_exists(self, remote_path: str) -> bool: + def path_exists(self, remote_path: str, is_dir: bool = False) -> bool: """ Check whether a remote path exists or not. """ - return self._engine.path_exists(remote_path) + return self._engine.path_exists(remote_path, is_dir) def get_file_size(self, remote_path: str) -> int: """ diff --git a/tests/integration/modules/ch_backup.py b/tests/integration/modules/ch_backup.py index 11879b80..4666f7dc 100644 --- a/tests/integration/modules/ch_backup.py +++ b/tests/integration/modules/ch_backup.py @@ -6,6 +6,7 @@ import os from copy import copy from typing import Sequence, Set, Union +from urllib.parse import quote import yaml @@ -172,8 +173,8 @@ def get_file_paths(self) -> Sequence[str]: part_path = os.path.join( part_obj.get("link") or backup_path, "data", - db_name, - table_name, + _quote(db_name), + _quote(table_name), part_name, ) if part_obj.get("tarball", False): @@ -443,3 +444,12 @@ def get_version() -> str: """ with open("ch_backup/version.txt", encoding="utf-8") as f: return f.read().strip() + + +def _quote(value: str) -> str: + return quote(value, safe="").translate( + { + ord("."): "%2E", + ord("-"): "%2D", + } + )