Skip to content

Commit

Permalink
OAuth prompt configurable through environment (#1456)
Browse files Browse the repository at this point in the history
Added environment variables `OAUTH_<PROVIDER>_PROMPT` and `OAUTH_PROMPT` to
override oauth prompt parameter, enabling users to explicitly enable login/consent prompts for oauth providers (e.g. `OAUTH_PROMPT=consent` to enable changing users/logging out).
  • Loading branch information
dokterbob committed Oct 22, 2024
1 parent 2162055 commit 45a3865
Showing 1 changed file with 43 additions and 9 deletions.
52 changes: 43 additions & 9 deletions backend/chainlit/oauth_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class OAuthProvider:
client_secret: str
authorize_url: str
authorize_params: Dict[str, str]
default_prompt: Optional[str] = None

def is_configured(self):
return all([os.environ.get(env) for env in self.env])
Expand All @@ -26,6 +27,21 @@ async def get_token(self, code: str, url: str) -> str:
async def get_user_info(self, token: str) -> Tuple[Dict[str, str], User]:
raise NotImplementedError()

def get_env_prefix(self) -> str:
"""Return environment prefix, like AZURE_AD."""

return self.id.replace("-", "_").upper()

def get_prompt(self) -> Optional[str]:
"""Return OAuth prompt param."""
if prompt := os.environ.get(f"OAUTH_{self.get_env_prefix()}_PROMPT"):
return prompt

if prompt := os.environ.get("OAUTH_PROMPT"):
return prompt

return self.default_prompt


class GithubOAuthProvider(OAuthProvider):
id = "github"
Expand All @@ -37,9 +53,11 @@ def __init__(self):
self.client_secret = os.environ.get("OAUTH_GITHUB_CLIENT_SECRET")
self.authorize_params = {
"scope": "user:email",
"prompt": "consent",
}

if prompt := self.get_prompt():
self.authorize_params["prompt"] = prompt

async def get_token(self, code: str, url: str):
payload = {
"client_id": self.client_id,
Expand Down Expand Up @@ -96,9 +114,11 @@ def __init__(self):
"scope": "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email",
"response_type": "code",
"access_type": "offline",
"prompt": "login",
}

if prompt := self.get_prompt():
self.authorize_params["prompt"] = prompt

async def get_token(self, code: str, url: str):
payload = {
"client_id": self.client_id,
Expand Down Expand Up @@ -164,9 +184,11 @@ def __init__(self):
"response_type": "code",
"scope": "https://graph.microsoft.com/User.Read",
"response_mode": "query",
"prompt": "login",
}

if prompt := self.get_prompt():
self.authorize_params["prompt"] = prompt

async def get_token(self, code: str, url: str):
payload = {
"client_id": self.client_id,
Expand Down Expand Up @@ -249,9 +271,11 @@ def __init__(self):
"scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid",
"response_mode": "form_post",
"nonce": nonce,
"prompt": "login",
}

if prompt := self.get_prompt():
self.authorize_params["prompt"] = prompt

async def get_token(self, code: str, url: str):
payload = {
"client_id": self.client_id,
Expand Down Expand Up @@ -329,9 +353,11 @@ def __init__(self):
"response_type": "code",
"scope": "openid profile email",
"response_mode": "query",
"prompt": "login",
}

if prompt := self.get_prompt():
self.authorize_params["prompt"] = prompt

def get_authorization_server_path(self):
if not self.authorization_server_id:
return "/default"
Expand Down Expand Up @@ -401,9 +427,11 @@ def __init__(self):
"response_type": "code",
"scope": "openid profile email",
"audience": f"{self.original_domain}/userinfo",
"prompt": "login",
}

if prompt := self.get_prompt():
self.authorize_params["prompt"] = prompt

async def get_token(self, code: str, url: str):
payload = {
"client_id": self.client_id,
Expand Down Expand Up @@ -459,9 +487,11 @@ def __init__(self):
"response_type": "code",
"scope": "openid profile email",
"audience": f"{self.domain}/userinfo",
"prompt": "login",
}

if prompt := self.get_prompt():
self.authorize_params["prompt"] = prompt

async def get_token(self, code: str, url: str):
payload = {
"client_id": self.client_id,
Expand Down Expand Up @@ -518,9 +548,11 @@ def __init__(self):
"response_type": "code",
"client_id": self.client_id,
"scope": "openid profile email",
"prompt": "login",
}

if prompt := self.get_prompt():
self.authorize_params["prompt"] = prompt

async def get_token(self, code: str, url: str):
payload = {
"client_id": self.client_id,
Expand Down Expand Up @@ -587,9 +619,11 @@ def __init__(self):
self.authorize_params = {
"scope": "openid profile email",
"response_type": "code",
"prompt": "login",
}

if prompt := self.get_prompt():
self.authorize_params["prompt"] = prompt

async def get_token(self, code: str, url: str):
payload = {
"client_id": self.client_id,
Expand Down

0 comments on commit 45a3865

Please sign in to comment.