From 5aa3cc9d735e32b7a6ae03dd1de25d94944d930f Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sun, 14 Jan 2024 13:02:33 +0000 Subject: [PATCH] S3: list_objects_v2() should have a hashed NextContinuationToken (#7187) --- moto/s3/models.py | 74 ++++++++++++++++++++++++++++++++++++---- moto/s3/responses.py | 61 +++++++++++++-------------------- tests/test_s3/test_s3.py | 2 +- 3 files changed, 91 insertions(+), 46 deletions(-) diff --git a/moto/s3/models.py b/moto/s3/models.py index a65f278b3b02..3837e7e5919a 100644 --- a/moto/s3/models.py +++ b/moto/s3/models.py @@ -1613,6 +1613,7 @@ def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) self.buckets: Dict[str, FakeBucket] = {} self.tagger = TaggingService() + self._pagination_tokens: Dict[str, str] = {} def reset(self) -> None: # For every key and multipart, Moto opens a TemporaryFile to write the value of those keys @@ -2442,8 +2443,13 @@ def upload_part_copy( return multipart.set_part(part_id, src_value) def list_objects( - self, bucket: FakeBucket, prefix: Optional[str], delimiter: Optional[str] - ) -> Tuple[Set[FakeKey], Set[str]]: + self, + bucket: FakeBucket, + prefix: Optional[str], + delimiter: Optional[str], + marker: Optional[str], + max_keys: int, + ) -> Tuple[Set[FakeKey], Set[str], bool, Optional[str]]: key_results = set() folder_results = set() if prefix: @@ -2474,16 +2480,70 @@ def list_objects( folder_name for folder_name in sorted(folder_results, key=lambda key: key) ] - return key_results, folder_results + if marker: + limit = self._pagination_tokens.get(marker) + key_results = self._get_results_from_token(key_results, limit) + + key_results, is_truncated, next_marker = self._truncate_result( + key_results, max_keys + ) + + return key_results, folder_results, is_truncated, next_marker def list_objects_v2( - self, bucket: FakeBucket, prefix: Optional[str], delimiter: Optional[str] - ) -> Set[Union[FakeKey, str]]: - result_keys, result_folders = self.list_objects(bucket, prefix, delimiter) + self, + bucket: FakeBucket, + prefix: Optional[str], + delimiter: Optional[str], + continuation_token: Optional[str], + start_after: Optional[str], + max_keys: int, + ) -> Tuple[Set[Union[FakeKey, str]], bool, Optional[str]]: + result_keys, result_folders, _, _ = self.list_objects( + bucket, prefix, delimiter, marker=None, max_keys=1000 + ) # sort the combination of folders and keys into lexicographical order all_keys = result_keys + result_folders # type: ignore all_keys.sort(key=self._get_name) - return all_keys + + if continuation_token or start_after: + limit = ( + self._pagination_tokens.get(continuation_token) + if continuation_token + else start_after + ) + all_keys = self._get_results_from_token(all_keys, limit) + + truncated_keys, is_truncated, next_continuation_token = self._truncate_result( + all_keys, max_keys + ) + + return truncated_keys, is_truncated, next_continuation_token + + def _get_results_from_token(self, result_keys: Any, token: Any) -> Any: + continuation_index = 0 + for key in result_keys: + if (key.name if isinstance(key, FakeKey) else key) > token: + break + continuation_index += 1 + return result_keys[continuation_index:] + + def _truncate_result(self, result_keys: Any, max_keys: int) -> Any: + if max_keys == 0: + result_keys = [] + is_truncated = True + next_continuation_token = None + elif len(result_keys) > max_keys: + is_truncated = "true" # type: ignore + result_keys = result_keys[:max_keys] + item = result_keys[-1] + key_id = item.name if isinstance(item, FakeKey) else item + next_continuation_token = md5_hash(key_id.encode("utf-8")).hexdigest() + self._pagination_tokens[next_continuation_token] = key_id + else: + is_truncated = "false" # type: ignore + next_continuation_token = None + return result_keys, is_truncated, next_continuation_token @staticmethod def _get_name(key: Union[str, FakeKey]) -> str: diff --git a/moto/s3/responses.py b/moto/s3/responses.py index 1b189521f4f7..bacea5426ab5 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -690,16 +690,19 @@ def _bucket_response_get( delimiter = querystring.get("delimiter", [None])[0] max_keys = int(querystring.get("max-keys", [1000])[0]) marker = querystring.get("marker", [None])[0] - result_keys, result_folders = self.backend.list_objects( - bucket, prefix, delimiter - ) encoding_type = querystring.get("encoding-type", [None])[0] - if marker: - result_keys = self._get_results_from_token(result_keys, marker) - - result_keys, is_truncated, next_marker = self._truncate_result( - result_keys, max_keys + ( + result_keys, + result_folders, + is_truncated, + next_marker, + ) = self.backend.list_objects( + bucket=bucket, + prefix=prefix, + delimiter=delimiter, + marker=marker, + max_keys=max_keys, ) template = self.response_template(S3_BUCKET_GET_RESPONSE) @@ -746,20 +749,25 @@ def _handle_list_objects_v2( if prefix and isinstance(prefix, bytes): prefix = prefix.decode("utf-8") delimiter = querystring.get("delimiter", [None])[0] - all_keys = self.backend.list_objects_v2(bucket, prefix, delimiter) fetch_owner = querystring.get("fetch-owner", [False])[0] max_keys = int(querystring.get("max-keys", [1000])[0]) start_after = querystring.get("start-after", [None])[0] encoding_type = querystring.get("encoding-type", [None])[0] - if continuation_token or start_after: - limit = continuation_token or start_after - all_keys = self._get_results_from_token(all_keys, limit) - - truncated_keys, is_truncated, next_continuation_token = self._truncate_result( - all_keys, max_keys + ( + truncated_keys, + is_truncated, + next_continuation_token, + ) = self.backend.list_objects_v2( + bucket=bucket, + prefix=prefix, + delimiter=delimiter, + continuation_token=continuation_token, + start_after=start_after, + max_keys=max_keys, ) + result_keys, result_folders = self._split_truncated_keys(truncated_keys) key_count = len(result_keys) + len(result_folders) @@ -796,29 +804,6 @@ def _split_truncated_keys(truncated_keys: Any) -> Any: # type: ignore[misc] result_folders.append(key) return result_keys, result_folders - def _get_results_from_token(self, result_keys: Any, token: Any) -> Any: - continuation_index = 0 - for key in result_keys: - if (key.name if isinstance(key, FakeKey) else key) > token: - break - continuation_index += 1 - return result_keys[continuation_index:] - - def _truncate_result(self, result_keys: Any, max_keys: int) -> Any: - if max_keys == 0: - result_keys = [] - is_truncated = True - next_continuation_token = None - elif len(result_keys) > max_keys: - is_truncated = "true" # type: ignore - result_keys = result_keys[:max_keys] - item = result_keys[-1] - next_continuation_token = item.name if isinstance(item, FakeKey) else item - else: - is_truncated = "false" # type: ignore - next_continuation_token = None - return result_keys, is_truncated, next_continuation_token - def _body_contains_location_constraint(self, body: bytes) -> bool: if body: try: diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index 550dd720f76e..bf8510863386 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -1441,7 +1441,7 @@ def test_list_objects_v2_truncate_combined_keys_and_folders(): assert len(resp["CommonPrefixes"]) == 1 assert resp["CommonPrefixes"][0]["Prefix"] == "1/" - last_tail = resp["NextContinuationToken"] + last_tail = resp["Contents"][-1]["Key"] resp = s3_client.list_objects_v2( Bucket="mybucket", MaxKeys=2, Prefix="", Delimiter="/", StartAfter=last_tail )