From 8e739b3e69eb7f96d5808314631f69f553b9fe6f Mon Sep 17 00:00:00 2001 From: Geoffrey Cleaves Date: Mon, 25 Nov 2024 13:43:10 +0100 Subject: [PATCH] Update oauth_providers.py to include Keycloak (#1525) --- backend/chainlit/oauth_providers.py | 66 +++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/backend/chainlit/oauth_providers.py b/backend/chainlit/oauth_providers.py index 82b09eeef2..744ffff621 100644 --- a/backend/chainlit/oauth_providers.py +++ b/backend/chainlit/oauth_providers.py @@ -665,6 +665,71 @@ async def get_user_info(self, token: str): return (gitlab_user, user) +class KeycloakOAuthProvider(OAuthProvider): + env = [ + "OAUTH_KEYCLOAK_CLIENT_ID", + "OAUTH_KEYCLOAK_CLIENT_SECRET", + "OAUTH_KEYCLOAK_REALM", + "OAUTH_KEYCLOAK_BASE_URL", + ] + id = os.environ.get("OAUTH_KEYCLOAK_NAME", "keycloak") + + def __init__(self): + self.client_id = os.environ.get("OAUTH_KEYCLOAK_CLIENT_ID") + self.client_secret = os.environ.get("OAUTH_KEYCLOAK_CLIENT_SECRET") + self.realm = os.environ.get("OAUTH_KEYCLOAK_REALM") + self.base_url = os.environ.get("OAUTH_KEYCLOAK_BASE_URL") + self.authorize_url = ( + f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/auth" + ) + + self.authorize_params = { + "scope": "profile email openid", + "response_type": "code", + } + + if prompt := self.get_prompt(): + self.authorize_params["prompt"] = prompt + + async def get_token(self, code: str, url: str): + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": url, + } + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/token", + data=payload, + ) + response.raise_for_status() + json = response.json() + token = json.get("access_token") + if not token: + raise httpx.HTTPStatusError( + "Failed to get the access token", + request=response.request, + response=response, + ) + return token + + async def get_user_info(self, token: str): + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/userinfo", + headers={"Authorization": f"Bearer {token}"}, + ) + response.raise_for_status() + kc_user = response.json() + user = User( + identifier=kc_user["email"], + metadata={"provider": "keycloak"}, + ) + return (kc_user, user) + + providers = [ GithubOAuthProvider(), GoogleOAuthProvider(), @@ -675,6 +740,7 @@ async def get_user_info(self, token: str): DescopeOAuthProvider(), AWSCognitoOAuthProvider(), GitlabOAuthProvider(), + KeycloakOAuthProvider(), ]