Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug: CORS Support in start-api #4991

Merged
merged 9 commits into from
May 18, 2023
26 changes: 26 additions & 0 deletions samcli/local/apigw/local_apigw_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ def _request_handler(self, **kwargs):

route: Route = self._get_current_route(request)
cors_headers = Cors.cors_to_headers(self.api.cors)
cors_headers = self._response_cors_headers(request, cors_headers)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for creating this PR! Did we need to restrict this behaviour to only run for HTTP APIs? From the docs and the linked issue, this seems to be something that the Lambda function deals with on for REST, but we would need to manage for HTTP.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @lucashuy

During the implementation, I referred to MDN web docs for reference:

image

I also referred to the implementation of middy cors middleware

below piece of code will basically parse and return a single origin if matched:
image

it is explained if looking deeper into the getOrigin() function:

image

Hope these help explain the reasoning behind my code.

Thanks,

lambda_authorizer = route.authorizer_object

# payloadFormatVersion can only support 2 values: "1.0" and "2.0"
Expand Down Expand Up @@ -800,6 +801,31 @@ def _get_current_route(self, flask_request):

return route

@staticmethod
def _response_cors_headers(flask_request, cors_headers):
if "Access-Control-Allow-Origin" not in cors_headers:
return cors_headers

cors_origins = cors_headers["Access-Control-Allow-Origin"]
# unset this header due to restrictive manner
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why unset this? It looks like it should be removed in the case that response_allowed_origin is not set. If so, I would move this to be closer to line 437. I think it just make it a little easier to read since the mutation of the cors_headers is all in one spot.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @jfuss , thank you for taking your time to review this

we unset and set again if a match is found (restrictive manner)

if multiple domains are allowed, we only send back 1 allowed domain in our response header
if all domains are allowed, we also only send back 1 allowed domain in our response header

we do not return this header (implying a deny) if no matches is found

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @jfuss please also refer to my below comment:
#4991 (comment)

del cors_headers["Access-Control-Allow-Origin"]

incoming_origin = flask_request.headers.get("Origin")
# Restrictive manner: do not allow any origin by default
response_allowed_origin = None
if incoming_origin:
if cors_origins == "*" and cors_headers.get("Access-Control-Allow-Credentials") is True:
response_allowed_origin = incoming_origin
else:
cors_origins_arr = cors_origins.split(",")
if incoming_origin in cors_origins_arr:
response_allowed_origin = incoming_origin

if response_allowed_origin:
cors_headers["Access-Control-Allow-Origin"] = response_allowed_origin

return cors_headers

@staticmethod
def get_request_methods_endpoints(flask_request):
"""
Expand Down
134 changes: 134 additions & 0 deletions tests/unit/local/apigw/test_local_apigw_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1835,6 +1835,140 @@ def test_lambda_output_json_object_no_status_code(self):
self.assertEqual(body, lambda_output)


class TestService_cors_response_headers(TestCase):
def test_response_cors_no_origin(self):
request_mock = Mock()
headers_mock = Mock()
headers_mock.keys.return_value = []
request_mock.headers = headers_mock
request_mock.scheme = "http"

cors = Cors(allow_origin="*", allow_methods="GET,POST,OPTIONS")

response_cors_headers = Cors.cors_to_headers(cors)
response_cors_headers = LocalApigwService._response_cors_headers(request_mock, response_cors_headers)

self.assertTrue("Access-Control-Allow-Origin" not in response_cors_headers)

def test_response_cors_with_origin(self):
incoming_origin = "localhost:3000"

request_mock = Mock()
headers_mock = Mock()
headers_mock.keys.return_value = []
headers_mock.keys.return_value = ["Origin", "Content-Type"]
headers_mock.get.side_effect = [incoming_origin, "application/json"]
headers_mock.getlist.side_effect = [[incoming_origin], ["application/json"]]
request_mock.headers = headers_mock
request_mock.scheme = "http"

cors = Cors(allow_origin="*", allow_methods="GET,POST,OPTIONS", allow_credentials=True)

response_cors_headers = Cors.cors_to_headers(cors)
response_cors_headers = LocalApigwService._response_cors_headers(request_mock, response_cors_headers)

self.assertEqual(incoming_origin, response_cors_headers["Access-Control-Allow-Origin"])

def test_response_cors_with_origin_single_domain(self):
incoming_origin = "localhost:3000"

request_mock = Mock()
headers_mock = Mock()
headers_mock.keys.return_value = []
headers_mock.keys.return_value = ["Origin", "Content-Type"]
headers_mock.get.side_effect = [incoming_origin, "application/json"]
headers_mock.getlist.side_effect = [[incoming_origin], ["application/json"]]
request_mock.headers = headers_mock
request_mock.scheme = "http"

cors = Cors(allow_origin="localhost:3000", allow_methods="GET,POST,OPTIONS", allow_credentials=True)

response_cors_headers = Cors.cors_to_headers(cors)
response_cors_headers = LocalApigwService._response_cors_headers(request_mock, response_cors_headers)

self.assertEqual(incoming_origin, response_cors_headers["Access-Control-Allow-Origin"])

def test_response_cors_with_origin_multi_domains(self):
incoming_origin = "localhost:3000"

request_mock = Mock()
headers_mock = Mock()
headers_mock.keys.return_value = []
headers_mock.keys.return_value = ["Origin", "Content-Type"]
headers_mock.get.side_effect = [incoming_origin, "application/json"]
headers_mock.getlist.side_effect = [[incoming_origin], ["application/json"]]
request_mock.headers = headers_mock
request_mock.scheme = "http"

cors = Cors(
allow_origin="localhost:3000,localhost:6000", allow_methods="GET,POST,OPTIONS", allow_credentials=True
)

response_cors_headers = Cors.cors_to_headers(cors)
response_cors_headers = LocalApigwService._response_cors_headers(request_mock, response_cors_headers)

self.assertEqual(incoming_origin, response_cors_headers["Access-Control-Allow-Origin"])

def test_response_cors_with_origin_multi_domains_not_matching(self):
incoming_origin = "localhost:3000"

request_mock = Mock()
headers_mock = Mock()
headers_mock.keys.return_value = []
headers_mock.keys.return_value = ["Origin", "Content-Type"]
headers_mock.get.side_effect = [incoming_origin, "application/json"]
headers_mock.getlist.side_effect = [[incoming_origin], ["application/json"]]
request_mock.headers = headers_mock
request_mock.scheme = "http"

cors = Cors(
allow_origin="localhost:4000,localhost:6000", allow_methods="GET,POST,OPTIONS", allow_credentials=True
)

response_cors_headers = Cors.cors_to_headers(cors)
response_cors_headers = LocalApigwService._response_cors_headers(request_mock, response_cors_headers)

self.assertTrue("Access-Control-Allow-Origin" not in response_cors_headers)

def test_response_cors_not_allow_credentials(self):
incoming_origin = "localhost:3000"

request_mock = Mock()
headers_mock = Mock()
headers_mock.keys.return_value = []
headers_mock.keys.return_value = ["Origin", "Content-Type"]
headers_mock.get.side_effect = [incoming_origin, "application/json"]
headers_mock.getlist.side_effect = [[incoming_origin], ["application/json"]]
request_mock.headers = headers_mock
request_mock.scheme = "http"

cors = Cors(allow_origin="*", allow_methods="GET,POST,OPTIONS", allow_credentials=False)

response_cors_headers = Cors.cors_to_headers(cors)
response_cors_headers = LocalApigwService._response_cors_headers(request_mock, response_cors_headers)

self.assertTrue("Access-Control-Allow-Origin" not in response_cors_headers)

def test_response_cors_missing_allow_credentials(self):
incoming_origin = "localhost:3000"

request_mock = Mock()
headers_mock = Mock()
headers_mock.keys.return_value = []
headers_mock.keys.return_value = ["Origin", "Content-Type"]
headers_mock.get.side_effect = [incoming_origin, "application/json"]
headers_mock.getlist.side_effect = [[incoming_origin], ["application/json"]]
request_mock.headers = headers_mock
request_mock.scheme = "http"

cors = Cors(allow_origin="*", allow_methods="GET,POST,OPTIONS")

response_cors_headers = Cors.cors_to_headers(cors)
response_cors_headers = LocalApigwService._response_cors_headers(request_mock, response_cors_headers)

self.assertTrue("Access-Control-Allow-Origin" not in response_cors_headers)


class TestServiceCorsToHeaders(TestCase):
def test_basic_conversion(self):
cors = Cors(
Expand Down