Skip to content

Commit

Permalink
Escape database, table and udf names for S3 paths (#168)
Browse files Browse the repository at this point in the history
* Escape database, table and udf names for S3 paths

* Comment typo

* Change get escaped path method

* Revert "Change get escaped path method"

This reverts commit 431a29f.

* Rename get escaped path method
  • Loading branch information
kirillgarbar authored Jul 22, 2024
1 parent de0178c commit 23e9954
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 16 deletions.
57 changes: 47 additions & 10 deletions ch_backup/backup/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
from pathlib import Path
from typing import Callable, List, Optional, Sequence
from typing import Any, Callable, List, Optional, Sequence
from urllib.parse import quote

from nacl.exceptions import CryptoError
Expand Down Expand Up @@ -248,7 +248,9 @@ def get_udf_create_statement(
"""
Download user defined function create statement.
"""
remote_path = _udf_data_path(backup_meta.path, filename)
remote_path = self._get_escaped_if_exists(
_udf_data_path, backup_meta.path, filename
)
return self._storage_loader.download_data(remote_path, encryption=True)

def get_local_nc_create_statement(self, nc_name: str) -> Optional[str]:
Expand Down Expand Up @@ -436,8 +438,12 @@ def download_data_part(

os.makedirs(fs_part_path, exist_ok=True)

remote_dir_path = _part_path(
part.link or backup_meta.path, part.database, part.table, part.name
remote_dir_path = self._get_escaped_if_exists(
_part_path,
part.link or backup_meta.path,
part.database,
part.table,
part.name,
)

if part.tarball:
Expand Down Expand Up @@ -474,8 +480,12 @@ def check_data_part(self, backup_path: str, part: PartMetadata) -> bool:
Check availability of part data in storage.
"""
try:
remote_dir_path = _part_path(
part.link or backup_path, part.database, part.table, part.name
remote_dir_path = self._get_escaped_if_exists(
_part_path,
part.link or backup_path,
part.database,
part.table,
part.name,
)
remote_files = self._storage_loader.list_dir(remote_dir_path)

Expand Down Expand Up @@ -560,8 +570,12 @@ def delete_data_parts(

deleting_files: List[str] = []
for part in parts:
part_path = _part_path(
part.link or backup_meta.path, part.database, part.table, part.name
part_path = self._get_escaped_if_exists(
_part_path,
part.link or backup_meta.path,
part.database,
part.table,
part.name,
)
logging.debug("Deleting data part {}", part_path)
if part.tarball:
Expand Down Expand Up @@ -615,6 +629,17 @@ def _target_part_size(self, part: PartMetadata) -> int:
tar_size, self._encryption_chunk_size, self._encryption_metadata_size
)

def _get_escaped_if_exists(
self, path_function: Callable, *args: Any, **kwargs: Any
) -> str:
"""
Return escaped path if it exists. Otherwise return regular path.
"""
path = path_function(*args, escape_names=True, **kwargs)
if self._storage_loader.path_exists(path, is_dir=True):
return path
return path_function(*args, escape_names=False, **kwargs)


def _access_control_data_path(backup_path: str, file_name: str) -> str:
"""
Expand All @@ -623,10 +648,12 @@ def _access_control_data_path(backup_path: str, file_name: str) -> str:
return os.path.join(backup_path, "access_control", file_name)


def _udf_data_path(backup_path: str, udf_file: str) -> str:
def _udf_data_path(backup_path: str, udf_file: str, escape_names: bool = True) -> str:
"""
Return S3 path to UDF data
"""
if escape_names:
return os.path.join(backup_path, "udf", _quote(udf_file))
return os.path.join(backup_path, "udf", udf_file)


Expand All @@ -653,10 +680,20 @@ def _named_collections_data_path(backup_path: str, nc_name: str) -> str:
return os.path.join(backup_path, "named_collections", _quote(nc_name) + ".sql")


def _part_path(backup_path: str, db_name: str, table_name: str, part_name: str) -> str:
def _part_path(
backup_path: str,
db_name: str,
table_name: str,
part_name: str,
escape_names: bool = True,
) -> str:
"""
Return S3 path to data part.
"""
if escape_names:
return os.path.join(
backup_path, "data", _quote(db_name), _quote(table_name), part_name
)
return os.path.join(backup_path, "data", db_name, table_name, part_name)


Expand Down
2 changes: 1 addition & 1 deletion ch_backup/storage/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def list_dir(
pass

@abstractmethod
def path_exists(self, remote_path: str) -> bool:
def path_exists(self, remote_path: str, is_dir: bool = False) -> bool:
"""
Check if remote path exists.
"""
Expand Down
22 changes: 21 additions & 1 deletion ch_backup/storage/engine/s3/s3_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,12 @@ def list_dir(

return contents

def path_exists(self, remote_path: str) -> bool:
def path_exists(self, remote_path: str, is_dir: bool = False) -> bool:
"""
Check if remote path exists.
"""
if is_dir:
return self._directory_exists(remote_path)
try:
self._s3_client.head_object(Bucket=self._s3_bucket_name, Key=remote_path)
return True
Expand All @@ -155,6 +157,24 @@ def path_exists(self, remote_path: str) -> bool:
return False
raise ce

def _directory_exists(self, remote_path: str) -> bool:
"""
Check if remote directory exists.
"""
remote_path = remote_path.rstrip("/")
resp = self._s3_client.list_objects(
Bucket=self._s3_bucket_name, Prefix=remote_path, Delimiter="/"
)
# CommonPrefixes contains all (if there are any) keys between
# Prefix and the next occurrence of the string specified by the delimiter.
# If there are more than 1000 keys which satisfy given Prefix this may not work,
# but there should be almost never more than one such key in our cases.
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/list_objects.html
for prefix in resp.get("CommonPrefixes", []):
if prefix["Prefix"].rstrip("/") == remote_path:
return True
return False

def create_multipart_upload(self, remote_path: str) -> str:
return self._multipart_uploader.create_multipart_upload(remote_path)

Expand Down
4 changes: 2 additions & 2 deletions ch_backup/storage/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,11 @@ def list_dir(
remote_path, recursive=recursive, absolute=absolute
)

def path_exists(self, remote_path: str, is_dir: bool = False) -> bool:
    """
    Report whether the given remote path (optionally a directory) exists.
    """
    return self._engine.path_exists(remote_path, is_dir)

def get_file_size(self, remote_path: str) -> int:
"""
Expand Down
14 changes: 12 additions & 2 deletions tests/integration/modules/ch_backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
from copy import copy
from typing import Sequence, Set, Union
from urllib.parse import quote

import yaml

Expand Down Expand Up @@ -172,8 +173,8 @@ def get_file_paths(self) -> Sequence[str]:
part_path = os.path.join(
part_obj.get("link") or backup_path,
"data",
db_name,
table_name,
_quote(db_name),
_quote(table_name),
part_name,
)
if part_obj.get("tarball", False):
Expand Down Expand Up @@ -443,3 +444,12 @@ def get_version() -> str:
"""
with open("ch_backup/version.txt", encoding="utf-8") as f:
return f.read().strip()


def _quote(value: str) -> str:
return quote(value, safe="").translate(
{
ord("."): "%2E",
ord("-"): "%2D",
}
)

0 comments on commit 23e9954

Please sign in to comment.