Skip to content

Commit

Permalink
S3: list_objects_v2() should have a hashed NextContinuationToken (#7187)
Browse files Browse the repository at this point in the history
  • Loading branch information
bblommers committed Jan 27, 2024
1 parent 1f1e0ca commit 5aa3cc9
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 46 deletions.
74 changes: 67 additions & 7 deletions moto/s3/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1613,6 +1613,7 @@ def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.buckets: Dict[str, FakeBucket] = {}
self.tagger = TaggingService()
self._pagination_tokens: Dict[str, str] = {}

def reset(self) -> None:
# For every key and multipart, Moto opens a TemporaryFile to write the value of those keys
Expand Down Expand Up @@ -2442,8 +2443,13 @@ def upload_part_copy(
return multipart.set_part(part_id, src_value)

def list_objects(
self, bucket: FakeBucket, prefix: Optional[str], delimiter: Optional[str]
) -> Tuple[Set[FakeKey], Set[str]]:
self,
bucket: FakeBucket,
prefix: Optional[str],
delimiter: Optional[str],
marker: Optional[str],
max_keys: int,
) -> Tuple[Set[FakeKey], Set[str], bool, Optional[str]]:
key_results = set()
folder_results = set()
if prefix:
Expand Down Expand Up @@ -2474,16 +2480,70 @@ def list_objects(
folder_name for folder_name in sorted(folder_results, key=lambda key: key)
]

return key_results, folder_results
if marker:
limit = self._pagination_tokens.get(marker)
key_results = self._get_results_from_token(key_results, limit)

key_results, is_truncated, next_marker = self._truncate_result(
key_results, max_keys
)

return key_results, folder_results, is_truncated, next_marker

def list_objects_v2(
self, bucket: FakeBucket, prefix: Optional[str], delimiter: Optional[str]
) -> Set[Union[FakeKey, str]]:
result_keys, result_folders = self.list_objects(bucket, prefix, delimiter)
self,
bucket: FakeBucket,
prefix: Optional[str],
delimiter: Optional[str],
continuation_token: Optional[str],
start_after: Optional[str],
max_keys: int,
) -> Tuple[Set[Union[FakeKey, str]], bool, Optional[str]]:
result_keys, result_folders, _, _ = self.list_objects(
bucket, prefix, delimiter, marker=None, max_keys=1000
)
# sort the combination of folders and keys into lexicographical order
all_keys = result_keys + result_folders # type: ignore
all_keys.sort(key=self._get_name)
return all_keys

if continuation_token or start_after:
limit = (
self._pagination_tokens.get(continuation_token)
if continuation_token
else start_after
)
all_keys = self._get_results_from_token(all_keys, limit)

truncated_keys, is_truncated, next_continuation_token = self._truncate_result(
all_keys, max_keys
)

return truncated_keys, is_truncated, next_continuation_token

def _get_results_from_token(self, result_keys: Any, token: Any) -> Any:
    """Return the tail of *result_keys* that sorts strictly after *token*.

    The keys are assumed to be in lexicographical order (entries may be
    FakeKey objects or plain folder-name strings).  An empty slice is
    returned when no key compares greater than the token.
    """
    for index, entry in enumerate(result_keys):
        name = entry.name if isinstance(entry, FakeKey) else entry
        if name > token:
            # First entry past the token - everything from here on is the page.
            return result_keys[index:]
    # Token sorts at or after every key: nothing left to return.
    return result_keys[len(result_keys):]

def _truncate_result(self, result_keys: Any, max_keys: int) -> Any:
    """Trim *result_keys* to at most *max_keys* entries.

    Returns a ``(keys, is_truncated, next_continuation_token)`` triple.
    NOTE: ``is_truncated`` is the bool ``True`` for the max_keys == 0
    case but the strings ``"true"``/``"false"`` otherwise - the response
    templates rely on these exact values, so they are preserved.
    """
    if max_keys == 0:
        # Zero keys requested: return nothing, but report truncation.
        return [], True, None
    if len(result_keys) <= max_keys:
        # Everything fits on a single page - no continuation needed.
        return result_keys, "false", None
    # Page is full: cut it down and remember where it ended so a
    # follow-up request can resume from the hashed continuation token.
    page = result_keys[:max_keys]
    last_entry = page[-1]
    key_id = last_entry.name if isinstance(last_entry, FakeKey) else last_entry
    next_continuation_token = md5_hash(key_id.encode("utf-8")).hexdigest()
    self._pagination_tokens[next_continuation_token] = key_id
    return page, "true", next_continuation_token

@staticmethod
def _get_name(key: Union[str, FakeKey]) -> str:
Expand Down
61 changes: 23 additions & 38 deletions moto/s3/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,16 +690,19 @@ def _bucket_response_get(
delimiter = querystring.get("delimiter", [None])[0]
max_keys = int(querystring.get("max-keys", [1000])[0])
marker = querystring.get("marker", [None])[0]
result_keys, result_folders = self.backend.list_objects(
bucket, prefix, delimiter
)
encoding_type = querystring.get("encoding-type", [None])[0]

if marker:
result_keys = self._get_results_from_token(result_keys, marker)

result_keys, is_truncated, next_marker = self._truncate_result(
result_keys, max_keys
(
result_keys,
result_folders,
is_truncated,
next_marker,
) = self.backend.list_objects(
bucket=bucket,
prefix=prefix,
delimiter=delimiter,
marker=marker,
max_keys=max_keys,
)

template = self.response_template(S3_BUCKET_GET_RESPONSE)
Expand Down Expand Up @@ -746,20 +749,25 @@ def _handle_list_objects_v2(
if prefix and isinstance(prefix, bytes):
prefix = prefix.decode("utf-8")
delimiter = querystring.get("delimiter", [None])[0]
all_keys = self.backend.list_objects_v2(bucket, prefix, delimiter)

fetch_owner = querystring.get("fetch-owner", [False])[0]
max_keys = int(querystring.get("max-keys", [1000])[0])
start_after = querystring.get("start-after", [None])[0]
encoding_type = querystring.get("encoding-type", [None])[0]

if continuation_token or start_after:
limit = continuation_token or start_after
all_keys = self._get_results_from_token(all_keys, limit)

truncated_keys, is_truncated, next_continuation_token = self._truncate_result(
all_keys, max_keys
(
truncated_keys,
is_truncated,
next_continuation_token,
) = self.backend.list_objects_v2(
bucket=bucket,
prefix=prefix,
delimiter=delimiter,
continuation_token=continuation_token,
start_after=start_after,
max_keys=max_keys,
)

result_keys, result_folders = self._split_truncated_keys(truncated_keys)

key_count = len(result_keys) + len(result_folders)
Expand Down Expand Up @@ -796,29 +804,6 @@ def _split_truncated_keys(truncated_keys: Any) -> Any: # type: ignore[misc]
result_folders.append(key)
return result_keys, result_folders

def _get_results_from_token(self, result_keys: Any, token: Any) -> Any:
continuation_index = 0
for key in result_keys:
if (key.name if isinstance(key, FakeKey) else key) > token:
break
continuation_index += 1
return result_keys[continuation_index:]

def _truncate_result(self, result_keys: Any, max_keys: int) -> Any:
if max_keys == 0:
result_keys = []
is_truncated = True
next_continuation_token = None
elif len(result_keys) > max_keys:
is_truncated = "true" # type: ignore
result_keys = result_keys[:max_keys]
item = result_keys[-1]
next_continuation_token = item.name if isinstance(item, FakeKey) else item
else:
is_truncated = "false" # type: ignore
next_continuation_token = None
return result_keys, is_truncated, next_continuation_token

def _body_contains_location_constraint(self, body: bytes) -> bool:
if body:
try:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_s3/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1441,7 +1441,7 @@ def test_list_objects_v2_truncate_combined_keys_and_folders():
assert len(resp["CommonPrefixes"]) == 1
assert resp["CommonPrefixes"][0]["Prefix"] == "1/"

last_tail = resp["NextContinuationToken"]
last_tail = resp["Contents"][-1]["Key"]
resp = s3_client.list_objects_v2(
Bucket="mybucket", MaxKeys=2, Prefix="", Delimiter="/", StartAfter=last_tail
)
Expand Down

0 comments on commit 5aa3cc9

Please sign in to comment.