Escape database, table and udf names for S3 paths
kirillgarbar committed Jul 1, 2024
1 parent ae46512 commit df857ff
Showing 5 changed files with 83 additions and 16 deletions.
57 changes: 47 additions & 10 deletions ch_backup/backup/layout.py
@@ -4,7 +4,7 @@

 import os
 from pathlib import Path
-from typing import Callable, List, Optional, Sequence
+from typing import Any, Callable, List, Optional, Sequence
 from urllib.parse import quote

 from ch_backup import logging
@@ -246,7 +246,9 @@ def get_udf_create_statement(
         """
         Download user defined function create statement.
         """
-        remote_path = _udf_data_path(backup_meta.path, filename)
+        remote_path = self._try_get_escaped_path(
+            _udf_data_path, backup_meta.path, filename
+        )
         return self._storage_loader.download_data(remote_path, encryption=True)

     def get_local_nc_create_statement(self, nc_name: str) -> Optional[str]:
@@ -419,8 +421,12 @@ def download_data_part(

         os.makedirs(fs_part_path, exist_ok=True)

-        remote_dir_path = _part_path(
-            part.link or backup_meta.path, part.database, part.table, part.name
+        remote_dir_path = self._try_get_escaped_path(
+            _part_path,
+            part.link or backup_meta.path,
+            part.database,
+            part.table,
+            part.name,
         )

         if part.tarball:
@@ -457,8 +463,12 @@ def check_data_part(self, backup_path: str, part: PartMetadata) -> bool:
         Check availability of part data in storage.
         """
         try:
-            remote_dir_path = _part_path(
-                part.link or backup_path, part.database, part.table, part.name
+            remote_dir_path = self._try_get_escaped_path(
+                _part_path,
+                part.link or backup_path,
+                part.database,
+                part.table,
+                part.name,
             )
             remote_files = self._storage_loader.list_dir(remote_dir_path)
@@ -543,8 +553,12 @@ def delete_data_parts(

         deleting_files: List[str] = []
         for part in parts:
-            part_path = _part_path(
-                part.link or backup_meta.path, part.database, part.table, part.name
+            part_path = self._try_get_escaped_path(
+                _part_path,
+                part.link or backup_meta.path,
+                part.database,
+                part.table,
+                part.name,
             )
             logging.debug("Deleting data part {}", part_path)
             if part.tarball:
@@ -598,6 +612,17 @@ def _target_part_size(self, part: PartMetadata) -> int:
             tar_size, self._encryption_chunk_size, self._encryption_metadata_size
         )

+    def _try_get_escaped_path(
+        self, path_function: Callable, *args: Any, **kwargs: Any
+    ) -> str:
+        """
+        Return the escaped path if it exists. Otherwise return the regular path.
+        """
+        path = path_function(*args, escape_names=True, **kwargs)
+        if self._storage_loader.path_exists(path, is_dir=True):
+            return path
+        return path_function(*args, escape_names=False, **kwargs)
+

 def _access_control_data_path(backup_path: str, file_name: str) -> str:
     """
@@ -606,10 +631,12 @@ def _access_control_data_path(backup_path: str, file_name: str) -> str:
     return os.path.join(backup_path, "access_control", file_name)


-def _udf_data_path(backup_path: str, udf_file: str) -> str:
+def _udf_data_path(backup_path: str, udf_file: str, escape_names: bool = True) -> str:
     """
     Return S3 path to UDF data
     """
+    if escape_names:
+        return os.path.join(backup_path, "udf", _quote(udf_file))
     return os.path.join(backup_path, "udf", udf_file)
@@ -636,10 +663,20 @@ def _named_collections_data_path(backup_path: str, nc_name: str) -> str:
     return os.path.join(backup_path, "named_collections", _quote(nc_name) + ".sql")


-def _part_path(backup_path: str, db_name: str, table_name: str, part_name: str) -> str:
+def _part_path(
+    backup_path: str,
+    db_name: str,
+    table_name: str,
+    part_name: str,
+    escape_names: bool = True,
+) -> str:
     """
     Return S3 path to data part.
     """
+    if escape_names:
+        return os.path.join(
+            backup_path, "data", _quote(db_name), _quote(table_name), part_name
+        )
     return os.path.join(backup_path, "data", db_name, table_name, part_name)


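Note that part names are left unescaped; only database and table names pass through _quote. A hypothetical illustration of the two layouts the restore-side fallback has to reconcile, assuming layout.py's _quote percent-encodes all special characters including "." and "-" (as the test helper at the end of this commit does):

_part_path("backups/b1", "my-db", "events.v2", "part_1_1", escape_names=True)
# -> "backups/b1/data/my%2Ddb/events%2Ev2/part_1_1"

_part_path("backups/b1", "my-db", "events.v2", "part_1_1", escape_names=False)
# -> "backups/b1/data/my-db/events.v2/part_1_1"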
2 changes: 1 addition & 1 deletion ch_backup/storage/engine/base.py
@@ -55,7 +55,7 @@ def list_dir(
         pass

     @abstractmethod
-    def path_exists(self, remote_path: str) -> bool:
+    def path_exists(self, remote_path: str, is_dir: bool = False) -> bool:
         """
         Check if remote path exists.
         """
22 changes: 21 additions & 1 deletion ch_backup/storage/engine/s3/s3_engine.py
@@ -142,10 +142,12 @@ def list_dir(

         return contents

-    def path_exists(self, remote_path: str) -> bool:
+    def path_exists(self, remote_path: str, is_dir: bool = False) -> bool:
         """
         Check if remote path exists.
         """
+        if is_dir:
+            return self._directory_exists(remote_path)
         try:
             self._s3_client.head_object(Bucket=self._s3_bucket_name, Key=remote_path)
             return True
@@ -155,6 +157,24 @@ def path_exists(self, remote_path: str) -> bool:
                 return False
             raise ce

+    def _directory_exists(self, remote_path: str) -> bool:
+        """
+        Check if remote directory exists.
+        """
+        remote_path = remote_path.rstrip("/")
+        resp = self._s3_client.list_objects(
+            Bucket=self._s3_bucket_name, Prefix=remote_path, Delimiter="/"
+        )
+        # CommonPrefixes contains all (if any) keys between the Prefix and the
+        # next occurrence of the delimiter string.
+        # If more than 1000 keys match the given Prefix, this may miss the match,
+        # but in our case there should almost never be more than one such key.
+        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/list_objects.html
+        for prefix in resp.get("CommonPrefixes", []):
+            if prefix["Prefix"].rstrip("/") == remote_path:
+                return True
+        return False
+
     def create_multipart_upload(self, remote_path: str) -> str:
         return self._multipart_uploader.create_multipart_upload(remote_path)

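S3 has no true directories: a "directory" exists only as a shared key prefix, which is why head_object cannot answer the is_dir query and list_objects with Delimiter="/" is used instead. A hedged sketch of the same check against a plain boto3 client (the bucket name and client setup are illustrative, not the engine's actual wiring):

import boto3


def directory_exists(client, bucket: str, path: str) -> bool:
    # A key like "data/db/table/part/file.bin" makes the prefix
    # "data/db/table/" appear in CommonPrefixes when the delimiter is "/".
    path = path.rstrip("/")
    resp = client.list_objects(Bucket=bucket, Prefix=path, Delimiter="/")
    return any(
        p["Prefix"].rstrip("/") == path for p in resp.get("CommonPrefixes", [])
    )


# Illustrative usage:
# directory_exists(boto3.client("s3"), "backup-bucket", "backups/b1/data/my%2Ddb")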
4 changes: 2 additions & 2 deletions ch_backup/storage/loader.py
@@ -203,11 +203,11 @@ def list_dir(
             remote_path, recursive=recursive, absolute=absolute
         )

-    def path_exists(self, remote_path: str) -> bool:
+    def path_exists(self, remote_path: str, is_dir: bool = False) -> bool:
         """
         Check whether a remote path exists or not.
         """
-        return self._engine.path_exists(remote_path)
+        return self._engine.path_exists(remote_path, is_dir)

     def get_file_size(self, remote_path: str) -> int:
         """
14 changes: 12 additions & 2 deletions tests/integration/modules/ch_backup.py
@@ -6,6 +6,7 @@
 import os
 from copy import copy
 from typing import Sequence, Set, Union
+from urllib.parse import quote

 import yaml

@@ -172,8 +173,8 @@ def get_file_paths(self) -> Sequence[str]:
                 part_path = os.path.join(
                     part_obj.get("link") or backup_path,
                     "data",
-                    db_name,
-                    table_name,
+                    _quote(db_name),
+                    _quote(table_name),
                     part_name,
                 )
                 if part_obj.get("tarball", False):
@@ -443,3 +444,12 @@ def get_version() -> str:
     """
     with open("ch_backup/version.txt", encoding="utf-8") as f:
         return f.read().strip()
+
+
+def _quote(value: str) -> str:
+    return quote(value, safe="").translate(
+        {
+            ord("."): "%2E",
+            ord("-"): "%2D",
+        }
+    )
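urllib.parse.quote with safe="" escapes everything except the unreserved characters, and "." and "-" are unreserved, so the extra translate step forces those two through as percent-escapes as well. A couple of illustrative values (not taken from the test suite):

>>> _quote("my-db")
'my%2Ddb'
>>> _quote("events.v2/raw")
'events%2Ev2%2Fraw'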
