Skip to content

Commit

Permalink
rework pkgpush auth (#2170)
Browse files Browse the repository at this point in the history
  • Loading branch information
sir-sigurd authored Apr 16, 2021
1 parent de6bfcb commit c2529b2
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 40 deletions.
36 changes: 19 additions & 17 deletions lambdas/pkgpush/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from http import HTTPStatus

import boto3
import requests
from botocore.exceptions import ClientError
from jsonschema import Draft7Validator

Expand All @@ -17,11 +16,9 @@
from quilt3.backends import get_package_registry
from quilt3.backends.s3 import S3PackageRegistryV1
from quilt3.util import PhysicalKey
from t4_lambda_shared.decorator import api
from t4_lambda_shared.decorator import ELBRequest, api
from t4_lambda_shared.utils import get_default_origins, make_json_response

AUTH_ENDPOINT = os.environ['AUTH_ENDPOINT']

PROMOTE_PKG_MAX_MANIFEST_SIZE = int(os.environ['PROMOTE_PKG_MAX_MANIFEST_SIZE'])
PROMOTE_PKG_MAX_PKG_SIZE = int(os.environ['PROMOTE_PKG_MAX_PKG_SIZE'])
PROMOTE_PKG_MAX_FILES = int(os.environ['PROMOTE_PKG_MAX_FILES'])
Expand Down Expand Up @@ -128,14 +125,22 @@
quilt3.data_transfer.S3ClientProvider.get_boto_session = staticmethod(lambda: user_boto_session)


def get_user_credentials(token):
resp = requests.get(AUTH_ENDPOINT, headers={'Authorization': token})
creds = resp.json()
return {
'aws_access_key_id': creds['AccessKeyId'],
'aws_secret_access_key': creds['SecretAccessKey'],
'aws_session_token': creds['SessionToken'],
def get_user_credentials(request):
attrs_map = (
('access_key', 'aws_access_key_id'),
('secret_key', 'aws_secret_access_key'),
('session_token', 'aws_session_token'),
)
creds = {
dst: request.args.get(src)
for src, dst in attrs_map
}
if not all(creds.values()):
raise ApiException(
HTTPStatus.BAD_REQUEST,
f'{", ".join(dict(attrs_map))} are required.'
)
return creds


# Isolated for test-ability.
Expand All @@ -155,10 +160,7 @@ def setup_user_boto_session(session):
def auth(f):
@functools.wraps(f)
def wrapper(request):
auth_header = request.headers.get('authorization')
if not auth_header:
return HTTPStatus.UNAUTHORIZED, '', {}
with setup_user_boto_session(get_user_boto_session(**get_user_credentials(auth_header))):
with setup_user_boto_session(get_user_boto_session(**get_user_credentials(request))):
return f(request)
return wrapper

Expand Down Expand Up @@ -292,7 +294,7 @@ def _push_pkg_to_successor(data, *, get_src, get_dst, get_name, get_pkg, pkg_max
raise ApiException(HTTPStatus.FORBIDDEN, e.message)


@api(cors_origins=get_default_origins())
@api(cors_origins=get_default_origins(), request_class=ELBRequest)
@api_exception_handler
@auth
@json_api(PACKAGE_PROMOTE_SCHEMA)
Expand Down Expand Up @@ -334,7 +336,7 @@ def get_pkg(src_registry, data):
)


@api(cors_origins=get_default_origins())
@api(cors_origins=get_default_origins(), request_class=ELBRequest)
@api_exception_handler
@auth
@json_api(PKG_FROM_FOLDER_SCHEMA)
Expand Down
43 changes: 22 additions & 21 deletions lambdas/pkgpush/tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@


class PackagePromoteTestBase(unittest.TestCase):
credentials = {
'access_key': mock.sentinel.TEST_ACCESS_KEY,
'secret_key': mock.sentinel.TEST_SECRET_KEY,
'session_token': mock.sentinel.TEST_SESSION_TOKEN,
}
handler = staticmethod(index.promote_package)
parent_bucket = 'parent-bucket'
src_registry = f's3://{parent_bucket}'
Expand Down Expand Up @@ -102,7 +107,6 @@ def setUpClass(cls):
def setUp(self):
super().setUp()
self.headers = {
'authorization': mock.sentinel.AUTH_TOKEN,
'content-type': 'application/json',
}
self.s3_stubber = Stubber(boto3.client('s3'))
Expand Down Expand Up @@ -140,23 +144,23 @@ def side_effect(registry_url):
yield

@classmethod
def _make_event(cls, body, headers=None):
def _make_event(cls, body, credentials):
return {
'httpMethod': 'POST',
'path': '/foo',
'pathParameters': {},
'queryStringParameters': None,
'headers': headers or None,
'queryStringParameters': credentials,
'headers': None,
'body': body,
'isBase64Encoded': False,
}

def make_request_base(self, params, *, headers):
def make_request_base(self, params, *, credentials):
# This is a function before it get wrapped with @api decorator.
# FIXME: find a cleaner way for this.
response = self.handler.__wrapped__(
Request(
self._make_event(json.dumps(params), headers=headers),
self._make_event(json.dumps(params), credentials=credentials),
)
)
status, body, headers = response
Expand All @@ -166,20 +170,14 @@ def make_request_base(self, params, *, headers):
@mock.patch('time.time', mock.MagicMock(return_value=mock_timestamp))
def make_request(self, *args, headers=None, **kwargs):
self.get_user_boto_session_mock.reset_mock()
get_user_credentials_patcher = mock.patch(
'index.get_user_credentials',
return_value={
'aws_access_key_id': mock.sentinel.USER_ACCESS_KEY,
'aws_secret_access_key': mock.sentinel.USER_SECRET_ACCESS_KEY,
'aws_session_token': mock.sentinel.USER_SESSION_TOKEN,
}
)
with get_user_credentials_patcher as get_user_credentials_mock, \
mock.patch('quilt3.telemetry.reset_session_id') as reset_session_id_mock:
response = self.make_request_base(*args, headers=headers or self.headers, **kwargs)
with mock.patch('quilt3.telemetry.reset_session_id') as reset_session_id_mock:
response = self.make_request_base(*args, credentials=self.credentials, **kwargs)

get_user_credentials_mock.assert_called_once_with(mock.sentinel.AUTH_TOKEN)
self.get_user_boto_session_mock.assert_called_once_with(**get_user_credentials_mock.return_value)
self.get_user_boto_session_mock.assert_called_once_with(
aws_access_key_id=mock.sentinel.TEST_ACCESS_KEY,
aws_secret_access_key=mock.sentinel.TEST_SECRET_KEY,
aws_session_token=mock.sentinel.TEST_SESSION_TOKEN,
)
reset_session_id_mock.assert_called_once_with()

return response
Expand Down Expand Up @@ -327,8 +325,11 @@ def test(self):
)

def test_no_auth(self):
resp = self.make_request_base({}, headers={})
assert (resp.status_code, resp.data) == (HTTPStatus.UNAUTHORIZED, b'')
resp = self.make_request_base({}, credentials={})
assert (resp.status_code, resp.data) == (
HTTPStatus.BAD_REQUEST,
b'{"message": "access_key, secret_key, session_token are required."}'
)

@mock.patch('quilt3.workflows.validate', lambda *args, **kwargs: None)
def test_dst_is_not_successor(self):
Expand Down
19 changes: 17 additions & 2 deletions lambdas/shared/t4_lambda_shared/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import gzip
import traceback
import urllib.parse
from base64 import b64decode, b64encode
from functools import wraps

Expand Down Expand Up @@ -33,11 +34,25 @@ def __init__(self, event):
self.data = event['body']


def api(cors_origins=()):
class ELBRequest(Request):
def __init__(self, event):
super().__init__(event)
# ELB pass queryStringParameters escaped.
self.args = dict(
urllib.parse.parse_qsl(
'&'.join(
f'{k}={v}'
for k, v in self.args.items()
)
)
)


def api(cors_origins=(), *, request_class=Request):
def innerdec(f):
@wraps(f)
def wrapper(event, _):
request = Request(event)
request = request_class(event)
if request.method == 'OPTIONS':
status = 200
response_headers = {}
Expand Down

0 comments on commit c2529b2

Please sign in to comment.