diff --git a/botocore/endpoint.py b/botocore/endpoint.py index 54442f4ad7..1c2cee068b 100644 --- a/botocore/endpoint.py +++ b/botocore/endpoint.py @@ -314,11 +314,6 @@ def _do_get_response(self, request, operation_model, context): response_dict, operation_model.output_shape ) parsed_response.update(customized_response_dict) - if updated_status_code := customized_response_dict.get( - 'updated_status_code' - ): - http_response.status_code = updated_status_code - del parsed_response['updated_status_code'] # Do a second parsing pass to pick up on any modeled error fields # NOTE: Ideally, we would push this down into the parser classes but # they currently have no reference to the operation or service model diff --git a/botocore/handlers.py b/botocore/handlers.py index 3639608652..2d17875cd3 100644 --- a/botocore/handlers.py +++ b/botocore/handlers.py @@ -1240,28 +1240,44 @@ def document_expires_shape(section, event_name, **kwargs): ) -def _handle_200_error( - operation_model, response_dict, customized_response_dict, **kwargs -): - if ( - not response_dict - or operation_model.has_streaming_output - or not operation_model.has_modeled_body_output - ): +def _handle_200_error(operation_model, response_dict, **kwargs): + # S3 can return a 200 OK response with an error embedded in the body. + # Conceptually, this should be handled like a 500 response in terms of + # raising exceptions and retries which we handle in _retry_200_error. + # This handler converts the 200 response to a 500 response to ensure + # correct error handling. + if not response_dict or operation_model.has_streaming_output: # Operations with streaming response blobs should be excluded as they # may contain customer content which mimics the form of an S3 error. return if _looks_like_special_case_error( response_dict['status_code'], response_dict['body'] ): - # The response_dict status code must be changed to be parsed as a 500 response. response_dict['status_code'] = 500 logger.debug( f"Error found for response with 200 status code: {response_dict['body']}. " - f"Changing status code to 500." + f"Changing the http_response status code to 500 will be propagated in " + f"the _retry_200_error handler." ) +def _retry_200_error(response, **kwargs): + # Adjusts the HTTP status code for responses that may contain errors + # embedded in a 200 OK response body. The _handle_200_error function + # modifies the parsed response status code to 500 if it detects an error. + # This function checks if the HTTP status code differs from the parsed + # status code and updates the HTTP response accordingly, ensuring + # correct handling for retries. + if response is None: + return + http_response, parsed = response + parsed_status_code = parsed.get('ResponseMetadata', {}).get( + 'HTTPStatusCode' + ) + if http_response.status_code != parsed_status_code: + http_response.status_code = parsed_status_code + + # This is a list of (event_name, handler). # When a Session is created, everything in this list will be # automatically registered with that Session. @@ -1336,6 +1352,7 @@ def _handle_200_error( ('before-call.ec2.CopySnapshot', inject_presigned_url_ec2), ('request-created', add_retry_headers), ('request-created.machinelearning.Predict', switch_host_machinelearning), + ('needs-retry.s3.*', _retry_200_error, REGISTER_FIRST), ('choose-signer.cognito-identity.GetId', disable_signing), ('choose-signer.cognito-identity.GetOpenIdToken', disable_signing), ('choose-signer.cognito-identity.UnlinkIdentity', disable_signing), diff --git a/tests/functional/configured_endpoint_urls/test_configured_endpoint_url.py b/tests/functional/configured_endpoint_urls/test_configured_endpoint_url.py index 51649d5bbc..40cc17e26d 100644 --- a/tests/functional/configured_endpoint_urls/test_configured_endpoint_url.py +++ b/tests/functional/configured_endpoint_urls/test_configured_endpoint_url.py @@ -147,9 +147,7 @@ def assert_endpoint_url_used_for_operation( ): http_stubber = ClientHTTPStubber(client) http_stubber.start() - http_stubber.add_response( - body=(b'' if operation == 'list_buckets' else None) - ) + http_stubber.add_response() # Call an operation on the client getattr(client, operation)(**params) diff --git a/tests/functional/test_credentials.py b/tests/functional/test_credentials.py index a45a73f14e..a6ba1f8b4b 100644 --- a/tests/functional/test_credentials.py +++ b/tests/functional/test_credentials.py @@ -1060,7 +1060,7 @@ def test_token_chosen_from_provider(self): session = Session(profile='sso-test') with SessionHTTPStubber(session) as stubber: self.add_credential_response(stubber) - stubber.add_response(body=b'') + stubber.add_response() with mock.patch.object( SSOTokenProvider, 'DEFAULT_CACHE_CLS', MockCache ): @@ -1155,7 +1155,7 @@ def test_credential_context_override(self): with SessionHTTPStubber(session) as stubber: s3 = session.create_client('s3') s3.meta.events.register('before-sign', self._add_fake_creds) - stubber.add_response(body=b'') + stubber.add_response() s3.list_buckets() request = stubber.requests[0] assert self.ACCESS_KEY in str(request.headers.get('Authorization')) diff --git a/tests/functional/test_regions.py b/tests/functional/test_regions.py index 6e22cc75e5..11a882f91f 100644 --- a/tests/functional/test_regions.py +++ b/tests/functional/test_regions.py @@ -502,7 +502,7 @@ def create_stubbed_client(self, service_name, region_name, **kwargs): def test_regionalized_client_endpoint_resolution(self): client, stubber = self.create_stubbed_client('s3', 'us-east-2') - stubber.add_response(body=b'') + stubber.add_response() client.list_buckets() self.assertEqual( stubber.requests[0].url, 'https://s3.us-east-2.amazonaws.com/' @@ -510,7 +510,7 @@ def test_regionalized_client_endpoint_resolution(self): def test_regionalized_client_with_unknown_region(self): client, stubber = self.create_stubbed_client('s3', 'not-real') - stubber.add_response(body=b'') + stubber.add_response() client.list_buckets() # Validate we don't fall back to partition endpoint for # regionalized services. diff --git a/tests/functional/test_s3.py b/tests/functional/test_s3.py index 422e74a455..04908b0551 100644 --- a/tests/functional/test_s3.py +++ b/tests/functional/test_s3.py @@ -434,12 +434,12 @@ def create_stubbed_s3_client(self, **kwargs): http_stubber.start() return client, http_stubber - def test_s3_copy_object_with_empty_response(self): + def test_s3_copy_object_with_incomplete_response(self): self.client, self.http_stubber = self.create_stubbed_s3_client( region_name="us-east-1" ) - empty_body = b"" + incomplete_body = b'\n\n\n' complete_body = ( b'\n\n' b"2020-04-21T21:03:31.000Z" b""s0mEcH3cK5uM"" ) - - self.http_stubber.add_response(status=200, body=empty_body) + self.http_stubber.add_response(status=200, body=incomplete_body) self.http_stubber.add_response(status=200, body=complete_body) response = self.client.copy_object( Bucket="bucket", @@ -551,7 +550,7 @@ def test_accesspoint_arn_with_custom_endpoint(self): self.client, http_stubber = self.create_stubbed_s3_client( endpoint_url="https://custom.com" ) - http_stubber.add_response(body=b'') + http_stubber.add_response() self.client.list_objects(Bucket=accesspoint_arn) expected_endpoint = "myendpoint-123456789012.custom.com" self.assert_endpoint(http_stubber.requests[0], expected_endpoint) @@ -564,7 +563,7 @@ def test_accesspoint_arn_with_custom_endpoint_and_dualstack(self): endpoint_url="https://custom.com", config=Config(s3={"use_dualstack_endpoint": True}), ) - http_stubber.add_response(body=b'') + http_stubber.add_response() self.client.list_objects(Bucket=accesspoint_arn) expected_endpoint = "myendpoint-123456789012.custom.com" self.assert_endpoint(http_stubber.requests[0], expected_endpoint) @@ -613,7 +612,7 @@ def test_signs_with_arn_region(self): self.client, self.http_stubber = self.create_stubbed_s3_client( region_name="us-east-1" ) - self.http_stubber.add_response(body=b'') + self.http_stubber.add_response() self.client.list_objects(Bucket=accesspoint_arn) self.assert_signing_region(self.http_stubber.requests[0], "us-west-2") @@ -737,7 +736,7 @@ def test_basic_outpost_arn(self): self.client, self.http_stubber = self.create_stubbed_s3_client( region_name="us-east-1" ) - self.http_stubber.add_response(body=b'') + self.http_stubber.add_response() self.client.list_objects(Bucket=outpost_arn) request = self.http_stubber.requests[0] self.assert_signing_name(request, "s3-outposts") @@ -759,7 +758,7 @@ def test_basic_outpost_arn_custom_endpoint(self): self.client, self.http_stubber = self.create_stubbed_s3_client( endpoint_url="https://custom.com", region_name="us-east-1" ) - self.http_stubber.add_response(body=b'') + self.http_stubber.add_response() self.client.list_objects(Bucket=outpost_arn) request = self.http_stubber.requests[0] self.assert_signing_name(request, "s3-outposts") @@ -963,7 +962,7 @@ def test_s3_object_lambda_arn_with_us_east_1(self): region_name="us-east-1", config=Config(s3={"use_arn_region": False}), ) - self.http_stubber.add_response(body=b'') + self.http_stubber.add_response() self.client.list_objects(Bucket=s3_object_lambda_arn) request = self.http_stubber.requests[0] self.assert_signing_name(request, "s3-object-lambda") @@ -981,7 +980,7 @@ def test_basic_s3_object_lambda_arn(self): self.client, self.http_stubber = self.create_stubbed_s3_client( region_name="us-east-1" ) - self.http_stubber.add_response(body=b'') + self.http_stubber.add_response() self.client.list_objects(Bucket=s3_object_lambda_arn) request = self.http_stubber.requests[0] self.assert_signing_name(request, "s3-object-lambda") @@ -1049,7 +1048,7 @@ def test_accesspoint_with_global_regions(self): config=Config(s3={"use_arn_region": True}), ) - self.http_stubber.add_response(body=b'') + self.http_stubber.add_response() self.client.list_objects(Bucket=s3_accesspoint_arn) request = self.http_stubber.requests[0] expected_endpoint = ( @@ -1063,7 +1062,7 @@ def test_accesspoint_with_global_regions(self): region_name="s3-external-1", ) - self.http_stubber.add_response(body=b'') + self.http_stubber.add_response() self.client.list_objects(Bucket=s3_accesspoint_arn) request = self.http_stubber.requests[0] expected_endpoint = ( @@ -1152,7 +1151,7 @@ def test_mrap_signing_algorithm_is_sigv4a(self): self.client, self.http_stubber = self.create_stubbed_s3_client( region_name="us-west-2" ) - self.http_stubber.add_response(body=b'') + self.http_stubber.add_response() self.client.list_objects(Bucket=s3_accesspoint_arn) request = self.http_stubber.requests[0] self._assert_sigv4a_used(request.headers) @@ -1239,7 +1238,7 @@ def _assert_mrap_endpoint( self.client, self.http_stubber = self.create_stubbed_s3_client( region_name=region, endpoint_url=endpoint_url, config=config ) - self.http_stubber.add_response(body=b'') + self.http_stubber.add_response() self.client.list_objects(Bucket=arn) request = self.http_stubber.requests[0] self.assert_endpoint(request, expected) @@ -1543,7 +1542,7 @@ def test_content_sha256_set_if_config_value_not_set_list_objects(self): "s3", self.region, config=config ) self.http_stubber = ClientHTTPStubber(self.client) - self.http_stubber.add_response(body=b'') + self.http_stubber.add_response() with self.http_stubber: self.client.list_objects(Bucket="foo") sent_headers = self.get_sent_headers() @@ -1561,7 +1560,7 @@ def test_content_sha256_set_s3_on_outpost(self): "s3", self.region, config=config ) self.http_stubber = ClientHTTPStubber(self.client) - self.http_stubber.add_response(body=b'') + self.http_stubber.add_response() with self.http_stubber: self.client.list_objects(Bucket=bucket) sent_headers = self.get_sent_headers() @@ -2210,7 +2209,7 @@ def test_checksums_included_in_expected_operations( """Validate expected calls include Content-MD5 header""" client = _create_s3_client() with ClientHTTPStubber(client) as stub: - stub.add_response(body=b'') + stub.add_response() call = getattr(client, operation) call(**operation_kwargs) assert "Content-MD5" in stub.requests[-1].headers @@ -3656,7 +3655,7 @@ def assert_correct_content_md5(self, request): self.assertEqual(content_md5, request.headers["Content-MD5"]) def test_escape_keys_in_xml_delete_objects(self): - self.http_stubber.add_response(body=b'') + self.http_stubber.add_response() with self.http_stubber: self.client.delete_objects( Bucket="mybucket", diff --git a/tests/functional/test_s3express.py b/tests/functional/test_s3express.py index a6e04bf661..390721ee12 100644 --- a/tests/functional/test_s3express.py +++ b/tests/functional/test_s3express.py @@ -280,7 +280,7 @@ def test_delete_objects_injects_correct_checksum( with ClientHTTPStubber(default_s3_client) as stubber: stubber.add_response(body=CREATE_SESSION_RESPONSE) - stubber.add_response(body=b'') + stubber.add_response() default_s3_client.delete_objects( Bucket=S3EXPRESS_BUCKET, diff --git a/tests/functional/test_useragent.py b/tests/functional/test_useragent.py index 90386e5157..79290459ab 100644 --- a/tests/functional/test_useragent.py +++ b/tests/functional/test_useragent.py @@ -27,7 +27,7 @@ class UACapHTTPStubber(ClientHTTPStubber): def __init__(self, obj_with_event_emitter): super().__init__(obj_with_event_emitter, strict=False) - self.add_response(body=b'') # expect exactly one request + self.add_response() # expect exactly one request @property def captured_ua_string(self): diff --git a/tests/unit/test_handlers.py b/tests/unit/test_handlers.py index f4f1ec19c4..f75eb1cb1c 100644 --- a/tests/unit/test_handlers.py +++ b/tests/unit/test_handlers.py @@ -1148,6 +1148,49 @@ def test_non_ascii_characters(self): ) +class TestRetryHandlerOrder(BaseSessionTest): + def get_handler_names(self, responses): + names = [] + for response in responses: + handler = response[0] + if hasattr(handler, '__name__'): + names.append(handler.__name__) + elif hasattr(handler, '__class__'): + names.append(handler.__class__.__name__) + else: + names.append(str(handler)) + return names + + def test_s3_special_case_is_before_other_retry(self): + client = self.session.create_client('s3') + service_model = self.session.get_service_model('s3') + operation = service_model.operation_model('CopyObject') + responses = client.meta.events.emit( + 'needs-retry.s3.CopyObject', + request_dict={'context': {}}, + response=(mock.Mock(), mock.Mock()), + endpoint=mock.Mock(), + operation=operation, + attempts=1, + caught_exception=None, + ) + # This is implementation specific, but we're trying to verify that + # the _retry_200_error is before any of the retry logic in + # botocore.retryhandlers. + # Technically, as long as the relative order is preserved, we don't + # care about the absolute order. + names = self.get_handler_names(responses) + self.assertIn('_retry_200_error', names) + self.assertIn('RetryHandler', names) + s3_200_handler = names.index('_retry_200_error') + general_retry_handler = names.index('RetryHandler') + self.assertTrue( + s3_200_handler < general_retry_handler, + "S3 200 error handler was supposed to be before " + "the general retry handler, but it was not.", + ) + + class BaseMD5Test(BaseSessionTest): def setUp(self, **environ): super().setUp(**environ) @@ -1880,7 +1923,6 @@ def test_document_response_params_without_expires(document_expires_mocks): def operation_model_for_200_error(): operation_model = mock.Mock() operation_model.has_streaming_output = False - operation_model.has_modeled_body_output = True return operation_model @@ -1900,54 +1942,55 @@ def response_dict_for_200_error(): def test_500_status_code_set_for_200_response( operation_model_for_200_error, response_dict_for_200_error ): - customized_response_dict = {} handlers._handle_200_error( - operation_model_for_200_error, - response_dict_for_200_error, - customized_response_dict, + operation_model_for_200_error, response_dict_for_200_error ) assert response_dict_for_200_error['status_code'] == 500 - assert customized_response_dict.get('updated_status_code') == 500 def test_200_response_with_no_error_left_untouched( operation_model_for_200_error, response_dict_for_200_error ): response_dict_for_200_error['body'] = b"" - customized_response_dict = {} handlers._handle_200_error( - operation_model_for_200_error, - response_dict_for_200_error, - customized_response_dict, + operation_model_for_200_error, response_dict_for_200_error ) # We don't touch the status code since there are no errors present. assert response_dict_for_200_error['status_code'] == 200 - assert customized_response_dict == {} def test_200_response_with_streaming_output_left_untouched( + operation_model_for_200_error, response_dict_for_200_error, ): - operation_model = mock.Mock() - operation_model.has_streaming_output = True - customized_response_dict = {} + operation_model_for_200_error.has_streaming_output = True handlers._handle_200_error( - operation_model, response_dict_for_200_error, customized_response_dict + operation_model_for_200_error, response_dict_for_200_error ) # We don't touch the status code on streaming operations. assert response_dict_for_200_error['status_code'] == 200 - assert customized_response_dict == {} def test_200_response_with_no_body_left_untouched( operation_model_for_200_error, response_dict_for_200_error ): - operation_model_for_200_error.has_modeled_body_output = False - customized_response_dict = {} + response_dict_for_200_error['body'] = b"" handlers._handle_200_error( - operation_model_for_200_error, - response_dict_for_200_error, - customized_response_dict, + operation_model_for_200_error, response_dict_for_200_error ) assert response_dict_for_200_error['status_code'] == 200 - assert customized_response_dict == {} + + +def test_http_status_code_updated_to_retry_200_response(): + http_response = mock.Mock() + http_response.status_code = 200 + parsed = {} + parsed.setdefault('ResponseMetadata', {})['HTTPStatusCode'] = 500 + handlers._retry_200_error((http_response, parsed)) + assert http_response.status_code == 500 + + +def test_500_response_can_be_none(): + # A 500 response can raise an exception, which means the response + # object is None. We need to handle this case. + handlers._retry_200_error(None) diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index f5c75cb0cc..da95a18bea 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -737,91 +737,6 @@ def test_not_streaming_output_for_operation(self): self.assertEqual(operation.get_streaming_output(), None) -class TestOperationModelBody(unittest.TestCase): - def setUp(self): - super().setUp() - self.model = { - 'operations': { - 'OperationName': { - 'name': 'OperationName', - 'input': { - 'shape': 'OperationRequest', - }, - 'output': { - 'shape': 'OperationResponse', - }, - }, - 'NoBodyOperation': { - 'name': 'NoBodyOperation', - 'input': {'shape': 'NoBodyOperationRequest'}, - 'output': {'shape': 'NoBodyOperationResponse'}, - }, - }, - 'shapes': { - 'OperationRequest': { - 'type': 'structure', - 'members': { - 'String': { - 'shape': 'stringType', - }, - "Body": { - 'shape': 'blobType', - }, - }, - 'payload': 'Body', - }, - 'OperationResponse': { - 'type': 'structure', - 'members': { - 'String': { - 'shape': 'stringType', - }, - "Body": { - 'shape': 'blobType', - }, - }, - 'payload': 'Body', - }, - 'NoBodyOperationRequest': { - 'type': 'structure', - 'members': { - 'data': { - 'location': 'header', - 'locationName': 'x-amz-data', - 'shape': 'stringType', - } - }, - }, - 'NoBodyOperationResponse': { - 'type': 'structure', - 'members': { - 'data': { - 'location': 'header', - 'locationName': 'x-amz-data', - 'shape': 'stringType', - } - }, - }, - 'stringType': { - 'type': 'string', - }, - 'blobType': {'type': 'blob'}, - }, - } - - def test_modeled_body_for_operation_with_body(self): - service_model = model.ServiceModel(self.model) - operation = service_model.operation_model('OperationName') - self.assertTrue(operation.has_modeled_body_input) - self.assertTrue(operation.has_modeled_body_output) - - def test_modeled_body_for_operation_with_no_body(self): - service_model = model.ServiceModel(self.model) - operation = service_model.operation_model('NoBodyOperation') - self.assertFalse(operation.has_modeled_body_input) - self.assertFalse(operation.has_modeled_body_output) - - class TestDeepMerge(unittest.TestCase): def setUp(self): self.shapes = { diff --git a/tests/unit/test_s3_addressing.py b/tests/unit/test_s3_addressing.py index 1629dbf880..7150d5b4ce 100644 --- a/tests/unit/test_s3_addressing.py +++ b/tests/unit/test_s3_addressing.py @@ -31,14 +31,12 @@ def setUp(self): set_list_objects_encoding_type_url, ) - def get_prepared_request( - self, operation, params, force_hmacv1=False, body=None - ): + def get_prepared_request(self, operation, params, force_hmacv1=False): if force_hmacv1: self.session.register('choose-signer', self.enable_hmacv1) client = self.session.create_client('s3', self.region_name) with ClientHTTPStubber(client) as http_stubber: - http_stubber.add_response(body=body) + http_stubber.add_response() getattr(client, operation)(**params) # Return the request that was sent over the wire. return http_stubber.requests[0] @@ -49,7 +47,7 @@ def enable_hmacv1(self, **kwargs): def test_list_objects_dns_name(self): params = {'Bucket': 'safename'} prepared_request = self.get_prepared_request( - 'list_objects', params, force_hmacv1=True, body=b'' + 'list_objects', params, force_hmacv1=True ) self.assertEqual( prepared_request.url, 'https://safename.s3.amazonaws.com/' @@ -58,7 +56,7 @@ def test_list_objects_dns_name(self): def test_list_objects_non_dns_name(self): params = {'Bucket': 'un_safe_name'} prepared_request = self.get_prepared_request( - 'list_objects', params, force_hmacv1=True, body=b'' + 'list_objects', params, force_hmacv1=True ) self.assertEqual( prepared_request.url, 'https://s3.amazonaws.com/un_safe_name' @@ -68,7 +66,7 @@ def test_list_objects_dns_name_non_classic(self): self.region_name = 'us-west-2' params = {'Bucket': 'safename'} prepared_request = self.get_prepared_request( - 'list_objects', params, force_hmacv1=True, body=b'' + 'list_objects', params, force_hmacv1=True ) self.assertEqual( prepared_request.url, @@ -80,9 +78,7 @@ def test_list_objects_unicode_query_string_eu_central_1(self): params = OrderedDict( [('Bucket', 'safename'), ('Marker', '\xe4\xf6\xfc-01.txt')] ) - prepared_request = self.get_prepared_request( - 'list_objects', params, body=b'' - ) + prepared_request = self.get_prepared_request('list_objects', params) self.assertEqual( prepared_request.url, ( @@ -94,9 +90,7 @@ def test_list_objects_unicode_query_string_eu_central_1(self): def test_list_objects_in_restricted_regions(self): self.region_name = 'us-gov-west-1' params = {'Bucket': 'safename'} - prepared_request = self.get_prepared_request( - 'list_objects', params, body=b'' - ) + prepared_request = self.get_prepared_request('list_objects', params) # Note how we keep the region specific endpoint here. self.assertEqual( prepared_request.url, @@ -106,9 +100,7 @@ def test_list_objects_in_restricted_regions(self): def test_list_objects_in_fips(self): self.region_name = 'fips-us-gov-west-1' params = {'Bucket': 'safename'} - prepared_request = self.get_prepared_request( - 'list_objects', params, body=b'' - ) + prepared_request = self.get_prepared_request('list_objects', params) # Note how we keep the region specific endpoint here. self.assertEqual( prepared_request.url, @@ -118,9 +110,7 @@ def test_list_objects_in_fips(self): def test_list_objects_non_dns_name_non_classic(self): self.region_name = 'us-west-2' params = {'Bucket': 'un_safe_name'} - prepared_request = self.get_prepared_request( - 'list_objects', params, body=b'' - ) + prepared_request = self.get_prepared_request('list_objects', params) self.assertEqual( prepared_request.url, 'https://s3.us-west-2.amazonaws.com/un_safe_name',