Skip to content

Commit

Permalink
move _has_modeled_body function and update tests that need bodies
Browse files Browse the repository at this point in the history
  • Loading branch information
alexgromero committed Oct 8, 2024
1 parent 2debf33 commit d4171b7
Show file tree
Hide file tree
Showing 11 changed files with 137 additions and 46 deletions.
12 changes: 1 addition & 11 deletions botocore/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def check_for_200_error(
if (
http_response is None
or operation_model.has_streaming_output
or not _has_modeled_body(operation_model)
or not operation_model.has_modeled_body_output
):
# A None response can happen if an exception is raised while
# trying to retrieve the response. See Endpoint._get_response().
Expand Down Expand Up @@ -190,16 +190,6 @@ def _looks_like_special_case_error(http_response):
return False


def _has_modeled_body(operation_model):
if output_shape := operation_model.output_shape:
for member_shape in output_shape.members.values():
if not member_shape.serialization.get('location'):
# If any member is not bound to a location,
# we can expect a body
return True
return False


def set_operation_specific_signer(context, signing_name, **kwargs):
"""Choose the operation-specific signer.
Expand Down
20 changes: 20 additions & 0 deletions botocore/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,26 @@ def _get_streaming_body(self, shape):
return payload_shape
return None

@CachedProperty
def has_modeled_body_input(self):
return self._has_modeled_body(self.input_shape)

@CachedProperty
def has_modeled_body_output(self):
return self._has_modeled_body(self.output_shape)

def _has_modeled_body(self, shape):
"""
Determines if an operation has a modeled body.
If any member is not bound to a location, we can expect a body.
"""
if shape is None:
return False
for member_shape in shape.members.values():
if not member_shape.serialization.get('location'):
return True
return False

def __repr__(self):
return f'{self.__class__.__name__}(name={self.name})'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def assert_endpoint_url_used_for_operation(
http_stubber = ClientHTTPStubber(client)
http_stubber.start()
http_stubber.add_response(
body=(b'<Test></Test>' if operation == 'list_buckets' else None)
body=(b'<Test/>' if operation == 'list_buckets' else None)
)

# Call an operation on the client
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,7 @@ def test_token_chosen_from_provider(self):
session = Session(profile='sso-test')
with SessionHTTPStubber(session) as stubber:
self.add_credential_response(stubber)
stubber.add_response(body=b'<Test></Test>')
stubber.add_response(body=b'<Test/>')
with mock.patch.object(
SSOTokenProvider, 'DEFAULT_CACHE_CLS', MockCache
):
Expand Down Expand Up @@ -1155,7 +1155,7 @@ def test_credential_context_override(self):
with SessionHTTPStubber(session) as stubber:
s3 = session.create_client('s3')
s3.meta.events.register('before-sign', self._add_fake_creds)
stubber.add_response(body=b'<Test></Test>')
stubber.add_response(body=b'<Test/>')
s3.list_buckets()
request = stubber.requests[0]
assert self.ACCESS_KEY in str(request.headers.get('Authorization'))
4 changes: 2 additions & 2 deletions tests/functional/test_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,15 +502,15 @@ def create_stubbed_client(self, service_name, region_name, **kwargs):

def test_regionalized_client_endpoint_resolution(self):
client, stubber = self.create_stubbed_client('s3', 'us-east-2')
stubber.add_response(body=b'<Test></Test>')
stubber.add_response(body=b'<Test/>')
client.list_buckets()
self.assertEqual(
stubber.requests[0].url, 'https://s3.us-east-2.amazonaws.com/'
)

def test_regionalized_client_with_unknown_region(self):
client, stubber = self.create_stubbed_client('s3', 'not-real')
stubber.add_response(body=b'<Test></Test>')
stubber.add_response(body=b'<Test/>')
client.list_buckets()
# Validate we don't fall back to partition endpoint for
# regionalized services.
Expand Down
30 changes: 15 additions & 15 deletions tests/functional/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def test_accesspoint_arn_with_custom_endpoint(self):
self.client, http_stubber = self.create_stubbed_s3_client(
endpoint_url="https://custom.com"
)
http_stubber.add_response(body=b'<Test></Test>')
http_stubber.add_response(body=b'<Test/>')
self.client.list_objects(Bucket=accesspoint_arn)
expected_endpoint = "myendpoint-123456789012.custom.com"
self.assert_endpoint(http_stubber.requests[0], expected_endpoint)
Expand All @@ -566,7 +566,7 @@ def test_accesspoint_arn_with_custom_endpoint_and_dualstack(self):
endpoint_url="https://custom.com",
config=Config(s3={"use_dualstack_endpoint": True}),
)
http_stubber.add_response(body=b'<Test></Test>')
http_stubber.add_response(body=b'<Test/>')
self.client.list_objects(Bucket=accesspoint_arn)
expected_endpoint = "myendpoint-123456789012.custom.com"
self.assert_endpoint(http_stubber.requests[0], expected_endpoint)
Expand Down Expand Up @@ -615,7 +615,7 @@ def test_signs_with_arn_region(self):
self.client, self.http_stubber = self.create_stubbed_s3_client(
region_name="us-east-1"
)
self.http_stubber.add_response(body=b'<Test></Test>')
self.http_stubber.add_response(body=b'<Test/>')
self.client.list_objects(Bucket=accesspoint_arn)
self.assert_signing_region(self.http_stubber.requests[0], "us-west-2")

Expand Down Expand Up @@ -739,7 +739,7 @@ def test_basic_outpost_arn(self):
self.client, self.http_stubber = self.create_stubbed_s3_client(
region_name="us-east-1"
)
self.http_stubber.add_response(body=b'<Test></Test>')
self.http_stubber.add_response(body=b'<Test/>')
self.client.list_objects(Bucket=outpost_arn)
request = self.http_stubber.requests[0]
self.assert_signing_name(request, "s3-outposts")
Expand All @@ -761,7 +761,7 @@ def test_basic_outpost_arn_custom_endpoint(self):
self.client, self.http_stubber = self.create_stubbed_s3_client(
endpoint_url="https://custom.com", region_name="us-east-1"
)
self.http_stubber.add_response(body=b'<Test></Test>')
self.http_stubber.add_response(body=b'<Test/>')
self.client.list_objects(Bucket=outpost_arn)
request = self.http_stubber.requests[0]
self.assert_signing_name(request, "s3-outposts")
Expand Down Expand Up @@ -965,7 +965,7 @@ def test_s3_object_lambda_arn_with_us_east_1(self):
region_name="us-east-1",
config=Config(s3={"use_arn_region": False}),
)
self.http_stubber.add_response(body=b'<Test></Test>')
self.http_stubber.add_response(body=b'<Test/>')
self.client.list_objects(Bucket=s3_object_lambda_arn)
request = self.http_stubber.requests[0]
self.assert_signing_name(request, "s3-object-lambda")
Expand All @@ -983,7 +983,7 @@ def test_basic_s3_object_lambda_arn(self):
self.client, self.http_stubber = self.create_stubbed_s3_client(
region_name="us-east-1"
)
self.http_stubber.add_response(body=b'<Test></Test>')
self.http_stubber.add_response(body=b'<Test/>')
self.client.list_objects(Bucket=s3_object_lambda_arn)
request = self.http_stubber.requests[0]
self.assert_signing_name(request, "s3-object-lambda")
Expand Down Expand Up @@ -1051,7 +1051,7 @@ def test_accesspoint_with_global_regions(self):
config=Config(s3={"use_arn_region": True}),
)

self.http_stubber.add_response(body=b'<Test></Test>')
self.http_stubber.add_response(body=b'<Test/>')
self.client.list_objects(Bucket=s3_accesspoint_arn)
request = self.http_stubber.requests[0]
expected_endpoint = (
Expand All @@ -1065,7 +1065,7 @@ def test_accesspoint_with_global_regions(self):
region_name="s3-external-1",
)

self.http_stubber.add_response(body=b'<Test></Test>')
self.http_stubber.add_response(body=b'<Test/>')
self.client.list_objects(Bucket=s3_accesspoint_arn)
request = self.http_stubber.requests[0]
expected_endpoint = (
Expand Down Expand Up @@ -1154,7 +1154,7 @@ def test_mrap_signing_algorithm_is_sigv4a(self):
self.client, self.http_stubber = self.create_stubbed_s3_client(
region_name="us-west-2"
)
self.http_stubber.add_response()
self.http_stubber.add_response(body=b'<Test/>')
self.client.list_objects(Bucket=s3_accesspoint_arn)
request = self.http_stubber.requests[0]
self._assert_sigv4a_used(request.headers)
Expand Down Expand Up @@ -1241,7 +1241,7 @@ def _assert_mrap_endpoint(
self.client, self.http_stubber = self.create_stubbed_s3_client(
region_name=region, endpoint_url=endpoint_url, config=config
)
self.http_stubber.add_response()
self.http_stubber.add_response(body=b'<Test/>')
self.client.list_objects(Bucket=arn)
request = self.http_stubber.requests[0]
self.assert_endpoint(request, expected)
Expand Down Expand Up @@ -1545,7 +1545,7 @@ def test_content_sha256_set_if_config_value_not_set_list_objects(self):
"s3", self.region, config=config
)
self.http_stubber = ClientHTTPStubber(self.client)
self.http_stubber.add_response(body=b'<Test></Test>')
self.http_stubber.add_response(body=b'<Test/>')
with self.http_stubber:
self.client.list_objects(Bucket="foo")
sent_headers = self.get_sent_headers()
Expand All @@ -1563,7 +1563,7 @@ def test_content_sha256_set_s3_on_outpost(self):
"s3", self.region, config=config
)
self.http_stubber = ClientHTTPStubber(self.client)
self.http_stubber.add_response(body=b'<Test></Test>')
self.http_stubber.add_response(body=b'<Test/>')
with self.http_stubber:
self.client.list_objects(Bucket=bucket)
sent_headers = self.get_sent_headers()
Expand Down Expand Up @@ -2212,7 +2212,7 @@ def test_checksums_included_in_expected_operations(
"""Validate expected calls include Content-MD5 header"""
client = _create_s3_client()
with ClientHTTPStubber(client) as stub:
stub.add_response(body=b'<Test></Test>')
stub.add_response(body=b'<Test/>')
call = getattr(client, operation)
call(**operation_kwargs)
assert "Content-MD5" in stub.requests[-1].headers
Expand Down Expand Up @@ -3658,7 +3658,7 @@ def assert_correct_content_md5(self, request):
self.assertEqual(content_md5, request.headers["Content-MD5"])

def test_escape_keys_in_xml_delete_objects(self):
self.http_stubber.add_response(body=b'<Test></Test>')
self.http_stubber.add_response(body=b'<Test/>')
with self.http_stubber:
self.client.delete_objects(
Bucket="mybucket",
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/test_s3express.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def test_delete_objects_injects_correct_checksum(

with ClientHTTPStubber(default_s3_client) as stubber:
stubber.add_response(body=CREATE_SESSION_RESPONSE)
stubber.add_response(body=b'<Test></Test>')
stubber.add_response(body=b'<Test/>')

default_s3_client.delete_objects(
Bucket=S3EXPRESS_BUCKET,
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/test_useragent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class UACapHTTPStubber(ClientHTTPStubber):

def __init__(self, obj_with_event_emitter):
super().__init__(obj_with_event_emitter, strict=False)
self.add_response(body=b'<Test></Test>') # expect exactly one request
self.add_response(body=b'<Test/>') # expect exactly one request

@property
def captured_ua_string(self):
Expand Down
8 changes: 2 additions & 6 deletions tests/unit/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1880,9 +1880,7 @@ def test_document_response_params_without_expires(document_expires_mocks):
def operation_model_for_200_error():
operation_model = mock.Mock()
operation_model.has_streaming_output = False
operation_model.output_shape = mock.Mock()
operation_model.output_shape.members = {'member': mock.Mock()}
operation_model.output_shape.members['member'].serialization = {}
operation_model.has_modeled_body_output = True
return operation_model


Expand Down Expand Up @@ -1944,9 +1942,7 @@ def test_200_response_with_streaming_output_left_untouched(
def test_200_response_with_no_body_left_untouched(
operation_model_for_200_error, response_dict, http_response
):
operation_model_for_200_error.output_shape.members[
'member'
].serialization = {'location': 'header'}
operation_model_for_200_error.has_modeled_body_output = False
handlers.check_for_200_error(
operation_model_for_200_error, response_dict, http_response
)
Expand Down
85 changes: 85 additions & 0 deletions tests/unit/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,91 @@ def test_not_streaming_output_for_operation(self):
self.assertEqual(operation.get_streaming_output(), None)


class TestOperationModelBody(unittest.TestCase):
def setUp(self):
super().setUp()
self.model = {
'operations': {
'OperationName': {
'name': 'OperationName',
'input': {
'shape': 'OperationRequest',
},
'output': {
'shape': 'OperationResponse',
},
},
'NoBodyOperation': {
'name': 'NoBodyOperation',
'input': {'shape': 'NoBodyOperationRequest'},
'output': {'shape': 'NoBodyOperationResponse'},
},
},
'shapes': {
'OperationRequest': {
'type': 'structure',
'members': {
'String': {
'shape': 'stringType',
},
"Body": {
'shape': 'blobType',
},
},
'payload': 'Body',
},
'OperationResponse': {
'type': 'structure',
'members': {
'String': {
'shape': 'stringType',
},
"Body": {
'shape': 'blobType',
},
},
'payload': 'Body',
},
'NoBodyOperationRequest': {
'type': 'structure',
'members': {
'data': {
'location': 'header',
'locationName': 'x-amz-data',
'shape': 'stringType',
}
},
},
'NoBodyOperationResponse': {
'type': 'structure',
'members': {
'data': {
'location': 'header',
'locationName': 'x-amz-data',
'shape': 'stringType',
}
},
},
'stringType': {
'type': 'string',
},
'blobType': {'type': 'blob'},
},
}

def test_modeled_body_for_operation_with_body(self):
service_model = model.ServiceModel(self.model)
operation = service_model.operation_model('OperationName')
self.assertTrue(operation.has_modeled_body_input)
self.assertTrue(operation.has_modeled_body_output)

def test_modeled_body_for_operation_with_no_body(self):
service_model = model.ServiceModel(self.model)
operation = service_model.operation_model('NoBodyOperation')
self.assertFalse(operation.has_modeled_body_input)
self.assertFalse(operation.has_modeled_body_output)


class TestDeepMerge(unittest.TestCase):
def setUp(self):
self.shapes = {
Expand Down
Loading

0 comments on commit d4171b7

Please sign in to comment.