Skip to content

Commit

Permalink
[v2] Multi-auth trait support (#8942)
Browse files Browse the repository at this point in the history
  • Loading branch information
aemous authored Oct 30, 2024
1 parent 78b9214 commit 9e6b1c4
Show file tree
Hide file tree
Showing 15 changed files with 343 additions and 14 deletions.
5 changes: 5 additions & 0 deletions .changes/next-release/feature-signing-67047.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "feature",
"category": "signing",
"description": "Adds internal support for the new 'auth' trait to allow a priority list of auth types for a service or operation."
}
14 changes: 14 additions & 0 deletions awscli/botocore/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,14 @@ def compute_client_args(self, service_model, client_config,
),
user_agent_extra=client_config.user_agent_extra,
user_agent_appid=client_config.user_agent_appid,
sigv4a_signing_region_set=(
client_config.sigv4a_signing_region_set
),
)
self._compute_retry_config(config_kwargs)
self._compute_request_compression_config(config_kwargs)
self._compute_user_agent_appid_config(config_kwargs)
self._compute_sigv4a_signing_region_set_config(config_kwargs)
s3_config = self.compute_s3_config(client_config)

is_s3_service = self._is_s3_service(service_name)
Expand Down Expand Up @@ -576,3 +580,13 @@ def _compute_user_agent_appid_config(self, config_kwargs):
f'maximum length of {USERAGENT_APPID_MAXLEN} characters.'
)
config_kwargs['user_agent_appid'] = user_agent_appid

def _compute_sigv4a_signing_region_set_config(self, config_kwargs):
sigv4a_signing_region_set = config_kwargs.get(
'sigv4a_signing_region_set'
)
if sigv4a_signing_region_set is None:
sigv4a_signing_region_set = self._config_store.get_config_variable(
'sigv4a_signing_region_set'
)
config_kwargs['sigv4a_signing_region_set'] = sigv4a_signing_region_set
27 changes: 26 additions & 1 deletion awscli/botocore/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@
urlsplit,
urlunsplit,
)
from botocore.exceptions import NoAuthTokenError, NoCredentialsError
from botocore.exceptions import (
NoAuthTokenError,
NoCredentialsError,
UnknownSignatureVersionError,
UnsupportedSignatureVersionError,
)
from botocore.utils import (
is_valid_ipv6_endpoint_url,
normalize_url_path,
Expand Down Expand Up @@ -851,6 +856,19 @@ def add_auth(self, request):
# a separate utility module to avoid any potential circular import.
import botocore.crt.auth

def resolve_auth_type(auth_trait):
for auth_type in auth_trait:
if auth_type == 'smithy.api#noAuth':
return AUTH_TYPE_TO_SIGNATURE_VERSION[auth_type]
elif auth_type in AUTH_TYPE_TO_SIGNATURE_VERSION:
signature_version = AUTH_TYPE_TO_SIGNATURE_VERSION[auth_type]
if signature_version in AUTH_TYPE_MAPS:
return signature_version
else:
raise UnknownSignatureVersionError(signature_version=auth_type)
raise UnsupportedSignatureVersionError(signature_version=auth_trait)


# Defined at the bottom instead of the top of the module because the Auth
# classes weren't defined yet.
AUTH_TYPE_MAPS = {
Expand All @@ -870,3 +888,10 @@ def add_auth(self, request):
'v4-s3express-presign-post': S3ExpressPostAuth,
'bearer': BearerAuth,
}

AUTH_TYPE_TO_SIGNATURE_VERSION = {
'aws.auth#sigv4': 'v4',
'aws.auth#sigv4a': 'v4a',
'smithy.api#httpBearerAuth': 'bearer',
'smithy.api#noAuth': 'none',
}
15 changes: 10 additions & 5 deletions awscli/botocore/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from botocore import UNSIGNED, waiter, xform_name
from botocore.args import ClientArgsCreator
from botocore.auth import AUTH_TYPE_MAPS
from botocore.auth import AUTH_TYPE_MAPS, resolve_auth_type
from botocore.awsrequest import prepare_request_dict
from botocore.compress import maybe_compress_request
from botocore.config import Config
Expand Down Expand Up @@ -118,13 +118,17 @@ def create_client(self, service_name, region_name, is_secure=True,
cls = self._create_client_class(service_name, service_model)
region_name, client_config = self._normalize_fips_region(
region_name, client_config)
if auth := service_model.metadata.get('auth'):
service_signature_version = resolve_auth_type(auth)
else:
service_signature_version = service_model.metadata.get(
'signatureVersion'
)
endpoint_bridge = ClientEndpointBridge(
self._endpoint_resolver, scoped_config, client_config,
service_signing_name=service_model.metadata.get('signingName'),
config_store=self._config_store,
service_signature_version=service_model.metadata.get(
'signatureVersion'
),
service_signature_version=service_signature_version,
)
client_args = self._get_client_args(
service_model, region_name, is_secure, endpoint_url,
Expand Down Expand Up @@ -678,7 +682,8 @@ def _make_api_call(self, operation_name, api_params):
'client_region': self.meta.region_name,
'client_config': self.meta.config,
'has_streaming_input': operation_model.has_streaming_input,
'auth_type': operation_model.auth_type,
'auth_type': operation_model.resolved_auth_type,
'unsigned_payload': operation_model.unsigned_payload,
}

api_params = self._emit_api_params(
Expand Down
6 changes: 6 additions & 0 deletions awscli/botocore/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ class Config(object):
set to True.
Defaults to None.
:type sigv4a_signing_region_set: string
:param sigv4a_signing_region_set: A set of AWS regions to apply the signature for
when using SigV4a for signing. Set to ``*`` to represent all regions.
Defaults to None.
"""
OPTION_DEFAULTS = OrderedDict([
('region_name', None),
Expand All @@ -212,6 +217,7 @@ class Config(object):
('ignore_configured_endpoint_urls', None),
('request_min_compression_size_bytes', None),
('disable_request_compression', None),
('sigv4a_signing_region_set', None),
])

def __init__(self, *args, **kwargs):
Expand Down
6 changes: 6 additions & 0 deletions awscli/botocore/configprovider.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@
False,
utils.ensure_boolean,
),
'sigv4a_signing_region_set': (
'sigv4a_signing_region_set',
'AWS_SIGV4A_SIGNING_REGION_SET',
None,
None,
),
}
# A mapping for the s3 specific configuration vars. These are the configuration
# vars that typically go in the s3 section of the config file. This mapping
Expand Down
2 changes: 1 addition & 1 deletion awscli/botocore/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ class UnknownClientMethodError(BotoCoreError):

class UnsupportedSignatureVersionError(BotoCoreError):
"""Error when trying to use an unsupported Signature Version."""
fmt = 'Signature version is not supported: {signature_version}'
fmt = 'Signature version(s) are not supported: {signature_version}'


class ClientError(Exception):
Expand Down
17 changes: 16 additions & 1 deletion awscli/botocore/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,20 @@ def set_operation_specific_signer(context, signing_name, **kwargs):
if auth_type == 'bearer':
return 'bearer'

# If the operation needs an unsigned body, we set additional context
# allowing the signer to be aware of this.
if context.get('unsigned_payload') or auth_type == 'v4-unsigned-body':
context['payload_signing_enabled'] = False

if auth_type.startswith('v4'):
if auth_type == 'v4-s3express':
return auth_type

if auth_type == 'v4a':
# If sigv4a is chosen, we must add additional signing config for
# global signature.
signing = {'region': '*', 'signing_name': signing_name}
region = _resolve_sigv4a_region(context)
signing = {'region': region, 'signing_name': signing_name}
if 'signing' in context:
context['signing'].update(signing)
else:
Expand All @@ -212,6 +218,15 @@ def set_operation_specific_signer(context, signing_name, **kwargs):
return signature_version


def _resolve_sigv4a_region(context):
region = None
if 'client_config' in context:
region = context['client_config'].sigv4a_signing_region_set
if not region and context.get('signing', {}).get('region'):
region = context['signing']['region']
return region or '*'


def decode_console_output(parsed, **kwargs):
if 'Output' in parsed:
try:
Expand Down
15 changes: 15 additions & 0 deletions awscli/botocore/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from collections import defaultdict
from typing import NamedTuple, Union

from botocore.auth import resolve_auth_type
from botocore.compat import OrderedDict
from botocore.exceptions import (
MissingServiceIdError,
Expand Down Expand Up @@ -587,10 +588,24 @@ def context_parameters(self):
def request_compression(self):
return self._operation_model.get('requestcompression')

@CachedProperty
def auth(self):
return self._operation_model.get('auth')

@CachedProperty
def auth_type(self):
return self._operation_model.get('authtype')

@CachedProperty
def resolved_auth_type(self):
if self.auth:
return resolve_auth_type(self.auth)
return self.auth_type

@CachedProperty
def unsigned_payload(self):
return self._operation_model.get('unsignedPayload')

@CachedProperty
def error_shapes(self):
shapes = self._operation_model.get("errors", [])
Expand Down
4 changes: 3 additions & 1 deletion awscli/botocore/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,9 @@ def auth_schemes_to_signing_ctx(self, auth_schemes):
signing_context['region'] = scheme['signingRegion']
elif 'signingRegionSet' in scheme:
if len(scheme['signingRegionSet']) > 0:
signing_context['region'] = scheme['signingRegionSet'][0]
signing_context['region'] = ','.join(
scheme['signingRegionSet']
)
if 'signingName' in scheme:
signing_context.update(signing_name=scheme['signingName'])
if 'disableDoubleEncoding' in scheme:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ filterwarnings = [
'default:The --rsyncdir command line argument and rsyncdirs config variable are deprecated.:DeprecationWarning'
]
markers = [
"slow"
"slow: marks tests as slow",
"validates_models: marks tests as one which validates service models",
]

[tool.black]
Expand Down
123 changes: 123 additions & 0 deletions tests/functional/botocore/test_auth_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import pytest

from botocore.config import Config
from tests import create_session, mock

# In the future, a service may have a list of credentials requirements where one
# signature may fail and others may succeed. e.g. a service may want to use bearer
# auth but fall back to sigv4 if a token isn't available. There's currently no way to do
# this in botocore, so this test ensures we handle this gracefully when the need arises.


# The dictionary's value here needs to be hashable to be added to the set below; any
# new auth types with multiple requirements should be added in a comma-separated list
AUTH_TYPE_REQUIREMENTS = {
'aws.auth#sigv4': 'credentials',
'aws.auth#sigv4a': 'credentials',
'smithy.api#httpBearerAuth': 'bearer_token',
'smithy.api#noAuth': 'none',
}


def _all_test_cases():
session = create_session()
loader = session.get_component('data_loader')

services = loader.list_available_services('service-2')
auth_services = []
auth_operations = []

for service in services:
service_model = session.get_service_model(service)
auth_config = service_model.metadata.get('auth', {})
if auth_config:
auth_services.append([service, auth_config])
for operation in service_model.operation_names:
operation_model = service_model.operation_model(operation)
if operation_model.auth:
auth_operations.append([service, operation_model])
return auth_services, auth_operations


AUTH_SERVICES, AUTH_OPERATIONS = _all_test_cases()


@pytest.mark.validates_models
@pytest.mark.parametrize("auth_service, auth_config", AUTH_SERVICES)
def test_all_requirements_match_for_service(auth_service, auth_config):
# Validates that all service-level signature types have the same requirements
message = f'Found mixed signer requirements for service: {auth_service}'
assert_all_requirements_match(auth_config, message)


@pytest.mark.validates_models
@pytest.mark.parametrize("auth_service, operation_model", AUTH_OPERATIONS)
def test_all_requirements_match_for_operation(auth_service, operation_model):
# Validates that all operation-level signature types have the same requirements
message = f'Found mixed signer requirements for operation: {auth_service}.{operation_model.name}'
auth_config = operation_model.auth
assert_all_requirements_match(auth_config, message)


def assert_all_requirements_match(auth_config, message):
auth_requirements = set(
AUTH_TYPE_REQUIREMENTS[auth_type] for auth_type in auth_config
)
assert len(auth_requirements) == 1, message


def get_config_file_path(base_path, value):
if value is None:
return "file-does-not-exist"

tmp_config_file_path = base_path / "config"
tmp_config_file_path.write_text(
f"[default]\nsigv4a_signing_region_set={value}\n"
)
return tmp_config_file_path


def get_environ_mock(
request,
env_var_value=None,
config_file_value=None,
):
base_path = request.getfixturevalue("tmp_path")
config_file_path = get_config_file_path(base_path, config_file_value)
return {
"AWS_CONFIG_FILE": str(config_file_path),
"AWS_SIGV4A_SIGNING_REGION_SET": env_var_value,
}


@pytest.mark.parametrize(
"client_config, env_var_val, config_file_val, expected",
[
(Config(sigv4a_signing_region_set="foo"), "bar", "baz", "foo"),
(Config(sigv4a_signing_region_set="foo"), None, None, "foo"),
(None, "bar", "baz", "bar"),
(None, None, "baz", "baz"),
(Config(sigv4a_signing_region_set="foo"), None, "baz", "foo"),
(None, None, None, None),
],
)
def test_sigv4a_signing_region_set_config_from_environment(
client_config, env_var_val, config_file_val, expected, request
):
environ_mock = get_environ_mock(request, env_var_val, config_file_val)
with mock.patch('os.environ', environ_mock):
session = create_session()
s3 = session.create_client('s3', config=client_config)
assert s3.meta.config.sigv4a_signing_region_set == expected
Loading

0 comments on commit 9e6b1c4

Please sign in to comment.