diff --git a/connexion/security.py b/connexion/security.py index 2ad092a60..62a10719d 100644 --- a/connexion/security.py +++ b/connexion/security.py @@ -493,11 +493,16 @@ def parse_security_scheme( security_handler = self.security_handlers["apiKey"] return security_handler().get_fn(security_scheme, required_scopes) - # Custom security handler - elif (scheme := security_scheme["scheme"].lower()) in self.security_handlers: + # Custom security scheme handler + elif "scheme" in security_scheme and (scheme := security_scheme["scheme"].lower()) in self.security_handlers: security_handler = self.security_handlers[scheme] return security_handler().get_fn(security_scheme, required_scopes) + # Custom security type handler + elif security_type in self.security_handlers: + security_handler = self.security_handlers[security_type] + return security_handler().get_fn(security_scheme, required_scopes) + else: logger.warning( "... Unsupported security scheme type %s", diff --git a/tests/api/test_secure_api.py b/tests/api/test_secure_api.py index cf547965f..a4225652a 100644 --- a/tests/api/test_secure_api.py +++ b/tests/api/test_secure_api.py @@ -6,6 +6,8 @@ from connexion.exceptions import OAuthProblem from connexion.security import NO_VALUE, BasicSecurityHandler, OAuthSecurityHandler +from tests.conftest import OPENAPI3_SPEC + class FakeResponse: def __init__(self, status_code, text): @@ -288,3 +290,88 @@ def wrapper(request): headers={"Authorization": "my_basic dGVzdDp0ZXN0"}, ) assert res.status_code == 200 + + +def test_security_map_custom_type(secure_api_spec_dir): + def generate_token(scopes): + token_segments = [ + json.dumps({"alg": "none", "typ": "JWT"}), + json.dumps({"sub": "1234567890", "name": "John Doe", "scopes": scopes}), + "", + ] + token = ".".join( + base64.urlsafe_b64encode(s.encode()).decode().rstrip("=") + for s in token_segments + ) + return token + + class FakeOIDCSecurityHandler(OAuthSecurityHandler): + """ + Uses openIdConnect (not currently directly implemented) as auth type to test custom/unimplemented auth types. + Doesn't attempt to actually implement OIDC + """ + + def _get_verify_func( + self, token_info_func, scope_validate_func, required_scopes + ): + check_oauth_func = self.check_oauth_func( + token_info_func, scope_validate_func + ) + + def wrapper(request): + auth_type, token = self.get_auth_header_value(request) + if auth_type != "bearer": + return NO_VALUE + + return check_oauth_func(request, token, required_scopes=required_scopes) + + return wrapper + + def get_tokeninfo_func(self, security_definition: dict): + def wrapper(token): + segments = token.split(".") + body = segments[1] + body += "=" * (-len(body) % 4) + return json.loads(base64.urlsafe_b64decode(body)) + + return wrapper + + security_map = { + "openIdConnect": FakeOIDCSecurityHandler, + } + # api level + app = App(__name__, specification_dir=secure_api_spec_dir) + app.add_api(OPENAPI3_SPEC, security_map=security_map) + app_client = app.test_client() + invalid_token = generate_token(["invalidscope"]) + res = app_client.post( + "/v1.0/greeting_oidc", + headers={"Authorization": f"bearer {invalid_token}"}, + ) + assert res.status_code == 403 + + valid_token = generate_token(["mytestscope"]) + res = app_client.post( + "/v1.0/greeting_oidc", + headers={"Authorization": f"bearer {valid_token}"}, + ) + assert res.status_code == 200 + + # app level + app = App( + __name__, specification_dir=secure_api_spec_dir, security_map=security_map + ) + app.add_api(OPENAPI3_SPEC) + app_client = app.test_client() + + res = app_client.post( + "/v1.0/greeting_oidc", + headers={"Authorization": f"bearer {invalid_token}"}, + ) + assert res.status_code == 403 + + res = app_client.post( + "/v1.0/greeting_oidc", + headers={"Authorization": f"bearer {valid_token}"}, + ) + assert res.status_code == 200 diff --git a/tests/fakeapi/hello/__init__.py b/tests/fakeapi/hello/__init__.py index b6d6c00ac..bac0316b3 100644 --- a/tests/fakeapi/hello/__init__.py +++ b/tests/fakeapi/hello/__init__.py @@ -48,6 +48,11 @@ def post_greeting_basic(): return data +def post_greeting_oidc(): + data = {"greeting": "Hello oidc"} + return data + + def post_greeting3(body, **kwargs): data = {"greeting": "Hello {name}".format(name=body["name"])} return data diff --git a/tests/fixtures/secure_api/openapi.yaml b/tests/fixtures/secure_api/openapi.yaml index a951791db..465d565c5 100644 --- a/tests/fixtures/secure_api/openapi.yaml +++ b/tests/fixtures/secure_api/openapi.yaml @@ -41,6 +41,21 @@ paths: type: object security: - basic: [] + '/greeting_oidc': + post: + summary: Generate greeting + description: Generates a greeting message. + operationId: fakeapi.hello.post_greeting_oidc + responses: + '200': + description: greeting response + content: + '*/*': + schema: + type: object + security: + - openIdConnect: + - mytestscope components: securitySchemes: oauth: @@ -56,3 +71,8 @@ components: scheme: basic description: Basic auth x-basicInfoFunc: fakeapi.auth.fake_basic_auth + + openIdConnect: + type: openIdConnect + description: Fake OIDC auth + openIdConnectUrl: https://oauth.example/.well-known/openid-configuration