Skip to content

Commit

Permalink
Add referral source to cloud on data plane (#3096)
Browse files Browse the repository at this point in the history
* cloud auth referral source

* minor clarity

* k

* minor modification to be best practice

* typing

* Update ReferralSourceSelector.tsx

* Update ReferralSourceSelector.tsx

---------

Co-authored-by: hagen-danswer <[email protected]>
  • Loading branch information
pablonyx and hagen-danswer authored Nov 13, 2024
1 parent fdc4811 commit 22189f0
Show file tree
Hide file tree
Showing 14 changed files with 176 additions and 21 deletions.
29 changes: 23 additions & 6 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,17 @@ async def create(
safe: bool = False,
request: Optional[Request] = None,
) -> User:
referral_source = None
if request is not None:
referral_source = request.cookies.get("referral_source", None)

tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user_create.email,
referral_source=referral_source,
)

async with get_async_session_with_tenant(tenant_id) as db_session:
Expand Down Expand Up @@ -294,12 +299,17 @@ async def oauth_callback(
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
referral_source = None
if request:
referral_source = getattr(request.state, "referral_source", None)

tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=account_email,
referral_source=referral_source,
)

if not tenant_id:
Expand Down Expand Up @@ -711,8 +721,6 @@ def generate_state_token(


# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91


def create_danswer_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
Expand Down Expand Up @@ -762,15 +770,22 @@ def get_oauth_router(
response_model=OAuth2AuthorizeResponse,
)
async def authorize(
request: Request, scopes: List[str] = Query(None)
request: Request,
scopes: List[str] = Query(None),
) -> OAuth2AuthorizeResponse:
referral_source = request.cookies.get("referral_source", None)

if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
authorize_redirect_url = str(request.url_for(callback_route_name))

next_url = request.query_params.get("next", "/")
state_data: Dict[str, str] = {"next_url": next_url}

state_data: Dict[str, str] = {
"next_url": next_url,
"referral_source": referral_source or "default_referral",
}
state = generate_state_token(state_data, state_secret)
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
Expand Down Expand Up @@ -829,8 +844,11 @@ async def callback(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)

next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)

# Authenticate user
request.state.referral_source = referral_source

# Proceed to authenticate or create the user
try:
user = await user_manager.oauth_callback(
oauth_client.name,
Expand Down Expand Up @@ -872,7 +890,6 @@ async def callback(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type

return redirect_response

return router
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def get_application() -> FastAPI:
tags=["users"],
)

if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD:
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,
Expand Down
27 changes: 27 additions & 0 deletions backend/ee/danswer/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import OpenID

from danswer.auth.users import auth_backend
Expand Down Expand Up @@ -59,6 +60,31 @@ def get_application() -> FastAPI:
if MULTI_TENANT:
add_tenant_id_middleware(application, logger)

if AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,
create_danswer_oauth_router(
oauth_client,
auth_backend,
USER_AUTH_SECRET,
associate_by_email=True,
is_verified_by_default=True,
# Points the user back to the login page
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
),
prefix="/auth/oauth",
tags=["auth"],
)

# Need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_logout_router(auth_backend),
prefix="/auth",
tags=["auth"],
)

if AUTH_TYPE == AuthType.OIDC:
include_router_with_global_prefix_prepended(
application,
Expand All @@ -73,6 +99,7 @@ def get_application() -> FastAPI:
prefix="/auth/oidc",
tags=["auth"],
)

# need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
application,
Expand Down
1 change: 1 addition & 0 deletions backend/ee/danswer/server/tenants/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ class ImpersonateRequest(BaseModel):
class TenantCreationPayload(BaseModel):
tenant_id: str
email: str
referral_source: str | None = None
18 changes: 12 additions & 6 deletions backend/ee/danswer/server/tenants/provisioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
logger = logging.getLogger(__name__)


async def get_or_create_tenant_id(email: str) -> str:
async def get_or_create_tenant_id(
email: str, referral_source: str | None = None
) -> str:
"""Get existing tenant ID for an email or create a new tenant if none exists."""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
Expand All @@ -51,7 +53,7 @@ async def get_or_create_tenant_id(email: str) -> str:
except exceptions.UserNotExists:
# If tenant does not exist and in Multi tenant mode, provision a new tenant
try:
tenant_id = await create_tenant(email)
tenant_id = await create_tenant(email, referral_source)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
Expand All @@ -64,13 +66,13 @@ async def get_or_create_tenant_id(email: str) -> str:
return tenant_id


async def create_tenant(email: str) -> str:
async def create_tenant(email: str, referral_source: str | None = None) -> str:
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
try:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane
await notify_control_plane(tenant_id, email)
await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)
Expand Down Expand Up @@ -117,14 +119,18 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


async def notify_control_plane(tenant_id: str, email: str) -> None:
async def notify_control_plane(
tenant_id: str, email: str, referral_source: str | None = None
) -> None:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
payload = TenantCreationPayload(tenant_id=tenant_id, email=email)
payload = TenantCreationPayload(
tenant_id=tenant_id, email=email, referral_source=referral_source
)

async with aiohttp.ClientSession() as session:
async with session.post(
Expand Down
2 changes: 2 additions & 0 deletions backend/tests/integration/common_utils/managers/tenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ class TenantManager:
def create(
tenant_id: str | None = None,
initial_admin_email: str | None = None,
referral_source: str | None = None,
) -> dict[str, str]:
body = {
"tenant_id": tenant_id,
"initial_admin_email": initial_admin_email,
"referral_source": referral_source,
}

token = generate_auth_token()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

def test_multi_tenant_access_control(reset_multitenant: None) -> None:
# Create Tenant 1 and its Admin User
TenantManager.create("tenant_dev1", "[email protected]")
TenantManager.create("tenant_dev1", "[email protected]", "Data Plane Registration")
test_user1: DATestUser = UserManager.create(name="test1", email="[email protected]")
assert UserManager.verify_role(test_user1, UserRole.ADMIN)

# Create Tenant 2 and its Admin User
TenantManager.create("tenant_dev2", "[email protected]")
TenantManager.create("tenant_dev2", "[email protected]", "Data Plane Registration")
test_user2: DATestUser = UserManager.create(name="test2", email="[email protected]")
assert UserManager.verify_role(test_user2, UserRole.ADMIN)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

# Test flow from creating tenant to registering as a user
def test_tenant_creation(reset_multitenant: None) -> None:
TenantManager.create("tenant_dev", "[email protected]")
TenantManager.create("tenant_dev", "[email protected]", "Data Plane Registration")
test_user: DATestUser = UserManager.create(name="test", email="[email protected]")

assert UserManager.verify_role(test_user, UserRole.ADMIN)
Expand Down
8 changes: 7 additions & 1 deletion web/src/app/auth/login/EmailPasswordForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ import { Spinner } from "@/components/Spinner";
export function EmailPasswordForm({
isSignup = false,
shouldVerify,
referralSource,
}: {
isSignup?: boolean;
shouldVerify?: boolean;
referralSource?: string;
}) {
const router = useRouter();
const { popup, setPopup } = usePopup();
Expand All @@ -39,7 +41,11 @@ export function EmailPasswordForm({
if (isSignup) {
// login is fast, no need to show a spinner
setIsWorking(true);
const response = await basicSignup(values.email, values.password);
const response = await basicSignup(
values.email,
values.password,
referralSource
);

if (!response.ok) {
const errorDetail = (await response.json()).detail;
Expand Down
8 changes: 6 additions & 2 deletions web/src/app/auth/login/SignInButton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,18 @@ export function SignInButton({
);
}

const url = new URL(authorizeUrl);

const finalAuthorizeUrl = url.toString();

if (!button) {
throw new Error(`Unhandled authType: ${authType}`);
}

return (
<a
className="mx-auto mt-6 py-3 w-72 text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
href={authorizeUrl}
className="mx-auto mt-6 py-3 w-full text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
href={finalAuthorizeUrl}
>
{button}
</a>
Expand Down
74 changes: 74 additions & 0 deletions web/src/app/auth/signup/ReferralSourceSelector.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"use client";

import { useState } from "react";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Label } from "@/components/admin/connectors/Field";

interface ReferralSourceSelectorProps {
defaultValue?: string;
}

const ReferralSourceSelector: React.FC<ReferralSourceSelectorProps> = ({
defaultValue,
}) => {
const [referralSource, setReferralSource] = useState(defaultValue);

const referralOptions = [
{ value: "search", label: "Search Engine (Google/Bing)" },
{ value: "friend", label: "Friend/Colleague" },
{ value: "linkedin", label: "LinkedIn" },
{ value: "twitter", label: "Twitter" },
{ value: "hackernews", label: "HackerNews" },
{ value: "reddit", label: "Reddit" },
{ value: "youtube", label: "YouTube" },
{ value: "podcast", label: "Podcast" },
{ value: "blog", label: "Article/Blog" },
{ value: "ads", label: "Advertisements" },
{ value: "other", label: "Other" },
];

const handleChange = (value: string) => {
setReferralSource(value);
const cookies = require("js-cookie");
cookies.set("referral_source", value, {
expires: 365,
path: "/",
sameSite: "strict",
});
};

return (
<div className="w-full max-w-sm gap-y-2 flex flex-col mx-auto">
<Label className="text-text-950" small={false}>
How did you hear about us?
</Label>
<Select value={referralSource} onValueChange={handleChange}>
<SelectTrigger
id="referral-source"
className="w-full border-gray-300 rounded-md shadow-sm focus:border-indigo-500 focus:ring-indigo-500"
>
<SelectValue placeholder="Select an option" />
</SelectTrigger>
<SelectContent className="max-h-60 overflow-y-auto">
{referralOptions.map((option) => (
<SelectItem
key={option.value}
value={option.value}
className="py-2 px-3 hover:bg-indigo-100 cursor-pointer"
>
{option.label}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
);
};

export default ReferralSourceSelector;
9 changes: 9 additions & 0 deletions web/src/app/auth/signup/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import Text from "@/components/ui/text";
import Link from "next/link";
import { SignInButton } from "../login/SignInButton";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
import ReferralSourceSelector from "./ReferralSourceSelector";
import { Separator } from "@/components/ui/separator";

const Page = async () => {
// catch cases where the backend is completely unreachable here
Expand Down Expand Up @@ -62,6 +64,13 @@ const Page = async () => {
<h2 className="text-center text-xl text-strong font-bold">
{cloud ? "Complete your sign up" : "Sign Up for Danswer"}
</h2>
{cloud && (
<>
<div className="w-full flex flex-col items-center space-y-4 mb-4 mt-4">
<ReferralSourceSelector />
</div>
</>
)}

{cloud && authUrl && (
<div className="w-full justify-center">
Expand Down
Loading

0 comments on commit 22189f0

Please sign in to comment.