From 22189f02c68190f40159679b410adbb12726a539 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Tue, 12 Nov 2024 16:42:25 -0800 Subject: [PATCH] Add referral source to cloud on data plane (#3096) * cloud auth referral source * minor clarity * k * minor modification to be best practice * typing * Update ReferralSourceSelector.tsx * Update ReferralSourceSelector.tsx --------- Co-authored-by: hagen-danswer --- backend/danswer/auth/users.py | 29 ++++++-- backend/danswer/main.py | 2 +- backend/ee/danswer/main.py | 27 +++++++ backend/ee/danswer/server/tenants/models.py | 1 + .../ee/danswer/server/tenants/provisioning.py | 18 +++-- .../common_utils/managers/tenant.py | 2 + .../syncing/test_search_permissions.py | 4 +- .../tenants/test_tenant_creation.py | 2 +- web/src/app/auth/login/EmailPasswordForm.tsx | 8 +- web/src/app/auth/login/SignInButton.tsx | 8 +- .../auth/signup/ReferralSourceSelector.tsx | 74 +++++++++++++++++++ web/src/app/auth/signup/page.tsx | 9 +++ web/src/lib/user.ts | 7 +- web/src/lib/userSS.ts | 6 +- 14 files changed, 176 insertions(+), 21 deletions(-) create mode 100644 web/src/app/auth/signup/ReferralSourceSelector.tsx diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 73f1ec18484..9804b5a37b7 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -228,12 +228,17 @@ async def create( safe: bool = False, request: Optional[Request] = None, ) -> User: + referral_source = None + if request is not None: + referral_source = request.cookies.get("referral_source", None) + tenant_id = await fetch_ee_implementation_or_noop( "danswer.server.tenants.provisioning", "get_or_create_tenant_id", async_return_default_schema, )( email=user_create.email, + referral_source=referral_source, ) async with get_async_session_with_tenant(tenant_id) as db_session: @@ -294,12 +299,17 @@ async def oauth_callback( associate_by_email: bool = False, is_verified_by_default: bool = False, ) -> models.UOAP: + referral_source = None + if request: + referral_source = getattr(request.state, "referral_source", None) + tenant_id = await fetch_ee_implementation_or_noop( "danswer.server.tenants.provisioning", "get_or_create_tenant_id", async_return_default_schema, )( email=account_email, + referral_source=referral_source, ) if not tenant_id: @@ -711,8 +721,6 @@ def generate_state_token( # refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91 - - def create_danswer_oauth_router( oauth_client: BaseOAuth2, backend: AuthenticationBackend, @@ -762,15 +770,22 @@ def get_oauth_router( response_model=OAuth2AuthorizeResponse, ) async def authorize( - request: Request, scopes: List[str] = Query(None) + request: Request, + scopes: List[str] = Query(None), ) -> OAuth2AuthorizeResponse: + referral_source = request.cookies.get("referral_source", None) + if redirect_url is not None: authorize_redirect_url = redirect_url else: authorize_redirect_url = str(request.url_for(callback_route_name)) next_url = request.query_params.get("next", "/") - state_data: Dict[str, str] = {"next_url": next_url} + + state_data: Dict[str, str] = { + "next_url": next_url, + "referral_source": referral_source or "default_referral", + } state = generate_state_token(state_data, state_secret) authorization_url = await oauth_client.get_authorization_url( authorize_redirect_url, @@ -829,8 +844,11 @@ async def callback( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) next_url = state_data.get("next_url", "/") + referral_source = state_data.get("referral_source", None) - # Authenticate user + request.state.referral_source = referral_source + + # Proceed to authenticate or create the user try: user = await user_manager.oauth_callback( oauth_client.name, @@ -872,7 +890,6 @@ async def callback( redirect_response.status_code = response.status_code if hasattr(response, "media_type"): redirect_response.media_type = response.media_type - return redirect_response return router diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 0aff801c8f4..ff2185dab7c 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -315,7 +315,7 @@ def get_application() -> FastAPI: tags=["users"], ) - if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD: + if AUTH_TYPE == AuthType.GOOGLE_OAUTH: oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET) include_router_with_global_prefix_prepended( application, diff --git a/backend/ee/danswer/main.py b/backend/ee/danswer/main.py index d09a2893d69..96655af2acd 100644 --- a/backend/ee/danswer/main.py +++ b/backend/ee/danswer/main.py @@ -1,4 +1,5 @@ from fastapi import FastAPI +from httpx_oauth.clients.google import GoogleOAuth2 from httpx_oauth.clients.openid import OpenID from danswer.auth.users import auth_backend @@ -59,6 +60,31 @@ def get_application() -> FastAPI: if MULTI_TENANT: add_tenant_id_middleware(application, logger) + if AUTH_TYPE == AuthType.CLOUD: + oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET) + include_router_with_global_prefix_prepended( + application, + create_danswer_oauth_router( + oauth_client, + auth_backend, + USER_AUTH_SECRET, + associate_by_email=True, + is_verified_by_default=True, + # Points the user back to the login page + redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback", + ), + prefix="/auth/oauth", + tags=["auth"], + ) + + # Need basic auth router for `logout` endpoint + include_router_with_global_prefix_prepended( + application, + fastapi_users.get_logout_router(auth_backend), + prefix="/auth", + tags=["auth"], + ) + if AUTH_TYPE == AuthType.OIDC: include_router_with_global_prefix_prepended( application, @@ -73,6 +99,7 @@ def get_application() -> FastAPI: prefix="/auth/oidc", tags=["auth"], ) + # need basic auth router for `logout` endpoint include_router_with_global_prefix_prepended( application, diff --git a/backend/ee/danswer/server/tenants/models.py b/backend/ee/danswer/server/tenants/models.py index df24ff6c32d..c372418f6a4 100644 --- a/backend/ee/danswer/server/tenants/models.py +++ b/backend/ee/danswer/server/tenants/models.py @@ -38,3 +38,4 @@ class ImpersonateRequest(BaseModel): class TenantCreationPayload(BaseModel): tenant_id: str email: str + referral_source: str | None = None diff --git a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py index e956cf4359c..32f95b1200d 100644 --- a/backend/ee/danswer/server/tenants/provisioning.py +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -41,7 +41,9 @@ logger = logging.getLogger(__name__) -async def get_or_create_tenant_id(email: str) -> str: +async def get_or_create_tenant_id( + email: str, referral_source: str | None = None +) -> str: """Get existing tenant ID for an email or create a new tenant if none exists.""" if not MULTI_TENANT: return POSTGRES_DEFAULT_SCHEMA @@ -51,7 +53,7 @@ async def get_or_create_tenant_id(email: str) -> str: except exceptions.UserNotExists: # If tenant does not exist and in Multi tenant mode, provision a new tenant try: - tenant_id = await create_tenant(email) + tenant_id = await create_tenant(email, referral_source) except Exception as e: logger.error(f"Tenant provisioning failed: {e}") raise HTTPException(status_code=500, detail="Failed to provision tenant.") @@ -64,13 +66,13 @@ async def get_or_create_tenant_id(email: str) -> str: return tenant_id -async def create_tenant(email: str) -> str: +async def create_tenant(email: str, referral_source: str | None = None) -> str: tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4()) try: # Provision tenant on data plane await provision_tenant(tenant_id, email) # Notify control plane - await notify_control_plane(tenant_id, email) + await notify_control_plane(tenant_id, email, referral_source) except Exception as e: logger.error(f"Tenant provisioning failed: {e}") await rollback_tenant_provisioning(tenant_id) @@ -117,14 +119,18 @@ async def provision_tenant(tenant_id: str, email: str) -> None: CURRENT_TENANT_ID_CONTEXTVAR.reset(token) -async def notify_control_plane(tenant_id: str, email: str) -> None: +async def notify_control_plane( + tenant_id: str, email: str, referral_source: str | None = None +) -> None: logger.info("Fetching billing information") token = generate_data_plane_token() headers = { "Authorization": f"Bearer {token}", "Content-Type": "application/json", } - payload = TenantCreationPayload(tenant_id=tenant_id, email=email) + payload = TenantCreationPayload( + tenant_id=tenant_id, email=email, referral_source=referral_source + ) async with aiohttp.ClientSession() as session: async with session.post( diff --git a/backend/tests/integration/common_utils/managers/tenant.py b/backend/tests/integration/common_utils/managers/tenant.py index 76fd16471f8..fc411018df7 100644 --- a/backend/tests/integration/common_utils/managers/tenant.py +++ b/backend/tests/integration/common_utils/managers/tenant.py @@ -28,10 +28,12 @@ class TenantManager: def create( tenant_id: str | None = None, initial_admin_email: str | None = None, + referral_source: str | None = None, ) -> dict[str, str]: body = { "tenant_id": tenant_id, "initial_admin_email": initial_admin_email, + "referral_source": referral_source, } token = generate_auth_token() diff --git a/backend/tests/integration/multitenant_tests/syncing/test_search_permissions.py b/backend/tests/integration/multitenant_tests/syncing/test_search_permissions.py index 454b02412d4..fead77387f6 100644 --- a/backend/tests/integration/multitenant_tests/syncing/test_search_permissions.py +++ b/backend/tests/integration/multitenant_tests/syncing/test_search_permissions.py @@ -14,12 +14,12 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None: # Create Tenant 1 and its Admin User - TenantManager.create("tenant_dev1", "test1@test.com") + TenantManager.create("tenant_dev1", "test1@test.com", "Data Plane Registration") test_user1: DATestUser = UserManager.create(name="test1", email="test1@test.com") assert UserManager.verify_role(test_user1, UserRole.ADMIN) # Create Tenant 2 and its Admin User - TenantManager.create("tenant_dev2", "test2@test.com") + TenantManager.create("tenant_dev2", "test2@test.com", "Data Plane Registration") test_user2: DATestUser = UserManager.create(name="test2", email="test2@test.com") assert UserManager.verify_role(test_user2, UserRole.ADMIN) diff --git a/backend/tests/integration/multitenant_tests/tenants/test_tenant_creation.py b/backend/tests/integration/multitenant_tests/tenants/test_tenant_creation.py index 6088743e317..c2dbcc81790 100644 --- a/backend/tests/integration/multitenant_tests/tenants/test_tenant_creation.py +++ b/backend/tests/integration/multitenant_tests/tenants/test_tenant_creation.py @@ -11,7 +11,7 @@ # Test flow from creating tenant to registering as a user def test_tenant_creation(reset_multitenant: None) -> None: - TenantManager.create("tenant_dev", "test@test.com") + TenantManager.create("tenant_dev", "test@test.com", "Data Plane Registration") test_user: DATestUser = UserManager.create(name="test", email="test@test.com") assert UserManager.verify_role(test_user, UserRole.ADMIN) diff --git a/web/src/app/auth/login/EmailPasswordForm.tsx b/web/src/app/auth/login/EmailPasswordForm.tsx index 334c74d14f7..06053fae5f2 100644 --- a/web/src/app/auth/login/EmailPasswordForm.tsx +++ b/web/src/app/auth/login/EmailPasswordForm.tsx @@ -14,9 +14,11 @@ import { Spinner } from "@/components/Spinner"; export function EmailPasswordForm({ isSignup = false, shouldVerify, + referralSource, }: { isSignup?: boolean; shouldVerify?: boolean; + referralSource?: string; }) { const router = useRouter(); const { popup, setPopup } = usePopup(); @@ -39,7 +41,11 @@ export function EmailPasswordForm({ if (isSignup) { // login is fast, no need to show a spinner setIsWorking(true); - const response = await basicSignup(values.email, values.password); + const response = await basicSignup( + values.email, + values.password, + referralSource + ); if (!response.ok) { const errorDetail = (await response.json()).detail; diff --git a/web/src/app/auth/login/SignInButton.tsx b/web/src/app/auth/login/SignInButton.tsx index 128f5790c6e..b06f9bad79b 100644 --- a/web/src/app/auth/login/SignInButton.tsx +++ b/web/src/app/auth/login/SignInButton.tsx @@ -36,14 +36,18 @@ export function SignInButton({ ); } + const url = new URL(authorizeUrl); + + const finalAuthorizeUrl = url.toString(); + if (!button) { throw new Error(`Unhandled authType: ${authType}`); } return ( {button} diff --git a/web/src/app/auth/signup/ReferralSourceSelector.tsx b/web/src/app/auth/signup/ReferralSourceSelector.tsx new file mode 100644 index 00000000000..5f2acd9fbc2 --- /dev/null +++ b/web/src/app/auth/signup/ReferralSourceSelector.tsx @@ -0,0 +1,74 @@ +"use client"; + +import { useState } from "react"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Label } from "@/components/admin/connectors/Field"; + +interface ReferralSourceSelectorProps { + defaultValue?: string; +} + +const ReferralSourceSelector: React.FC = ({ + defaultValue, +}) => { + const [referralSource, setReferralSource] = useState(defaultValue); + + const referralOptions = [ + { value: "search", label: "Search Engine (Google/Bing)" }, + { value: "friend", label: "Friend/Colleague" }, + { value: "linkedin", label: "LinkedIn" }, + { value: "twitter", label: "Twitter" }, + { value: "hackernews", label: "HackerNews" }, + { value: "reddit", label: "Reddit" }, + { value: "youtube", label: "YouTube" }, + { value: "podcast", label: "Podcast" }, + { value: "blog", label: "Article/Blog" }, + { value: "ads", label: "Advertisements" }, + { value: "other", label: "Other" }, + ]; + + const handleChange = (value: string) => { + setReferralSource(value); + const cookies = require("js-cookie"); + cookies.set("referral_source", value, { + expires: 365, + path: "/", + sameSite: "strict", + }); + }; + + return ( +
+ + +
+ ); +}; + +export default ReferralSourceSelector; diff --git a/web/src/app/auth/signup/page.tsx b/web/src/app/auth/signup/page.tsx index 29c2f97ea16..223faff331d 100644 --- a/web/src/app/auth/signup/page.tsx +++ b/web/src/app/auth/signup/page.tsx @@ -12,6 +12,8 @@ import Text from "@/components/ui/text"; import Link from "next/link"; import { SignInButton } from "../login/SignInButton"; import AuthFlowContainer from "@/components/auth/AuthFlowContainer"; +import ReferralSourceSelector from "./ReferralSourceSelector"; +import { Separator } from "@/components/ui/separator"; const Page = async () => { // catch cases where the backend is completely unreachable here @@ -62,6 +64,13 @@ const Page = async () => {

{cloud ? "Complete your sign up" : "Sign Up for Danswer"}

+ {cloud && ( + <> +
+ +
+ + )} {cloud && authUrl && (
diff --git a/web/src/lib/user.ts b/web/src/lib/user.ts index 10d426c1ec9..d540366f519 100644 --- a/web/src/lib/user.ts +++ b/web/src/lib/user.ts @@ -43,7 +43,11 @@ export const basicLogin = async ( return response; }; -export const basicSignup = async (email: string, password: string) => { +export const basicSignup = async ( + email: string, + password: string, + referralSource?: string +) => { const response = await fetch("/api/auth/register", { method: "POST", credentials: "include", @@ -54,6 +58,7 @@ export const basicSignup = async (email: string, password: string) => { email, username: email, password, + referral_source: referralSource, }), }); return response; diff --git a/web/src/lib/userSS.ts b/web/src/lib/userSS.ts index 55b916d34b3..906f23fa8b2 100644 --- a/web/src/lib/userSS.ts +++ b/web/src/lib/userSS.ts @@ -63,7 +63,11 @@ const getOIDCAuthUrlSS = async (nextUrl: string | null): Promise => { }; const getGoogleOAuthUrlSS = async (): Promise => { - const res = await fetch(buildUrl(`/auth/oauth/authorize`)); + const res = await fetch(buildUrl(`/auth/oauth/authorize`), { + headers: { + cookie: processCookies(await cookies()), + }, + }); if (!res.ok) { throw new Error("Failed to fetch data"); }