Skip to content

Commit

Permalink
add auth
Browse files Browse the repository at this point in the history
  • Loading branch information
pablonyx committed Oct 5, 2024
1 parent 440dee0 commit f6ad21e
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 31 deletions.
23 changes: 14 additions & 9 deletions backend/danswer/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from textwrap import dedent
from typing import Any

from danswer.configs.app_configs import SMTP_PASS
Expand Down Expand Up @@ -58,23 +59,27 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
def send_user_email_invite(user_email: str, current_user: User) -> None:
msg = MIMEMultipart()
msg["Subject"] = "Invitation to Join Danswer Workspace"
msg["To"] = user_email
msg["From"] = current_user.email
msg["To"] = user_email

email_body = f"""
Hello,
email_body = dedent(
f"""\
Hello,
You have been invited to join a workspace on Danswer.
You have been invited to join a workspace on Danswer.
To join the workspace, please do so at the following link:
{WEB_DOMAIN}/auth/login
To join the workspace, please visit the following link:
Best regards,
The Danswer Team"""
{WEB_DOMAIN}/auth/login
msg.attach(MIMEText(email_body, "plain"))
Best regards,
The Danswer Team
"""
)

msg.attach(MIMEText(email_body, "plain"))
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as smtp_server:
smtp_server.starttls()
smtp_server.login(SMTP_USER, SMTP_PASS)
smtp_server.send_message(msg)
print(f"Invitation email sent to {user_email}.")
22 changes: 22 additions & 0 deletions backend/ee/danswer/server/tenants/access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os
from datetime import datetime
from datetime import timedelta

import jwt

DATA_PLANE_SECRET = os.getenv("DATA_PLANE_SECRET")
ALGORITHM = "HS256"


def generate_data_plane_token() -> str:
if DATA_PLANE_SECRET is None:
raise ValueError("DATA_PLANE_SECRET is not set")

payload = {
"iss": "data_plane",
"exp": datetime.utcnow() + timedelta(minutes=5),
"iat": datetime.utcnow(),
"scope": "api_access",
}
token = jwt.encode(payload, DATA_PLANE_SECRET, algorithm=ALGORITHM)
return token
3 changes: 1 addition & 2 deletions backend/ee/danswer/server/tenants/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,12 @@ async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation:
logger.info("Fetching billing information")
return fetch_billing_information(current_tenant_id.get())
return BillingInformation(**fetch_billing_information(current_tenant_id.get()))


@router.post("/create-customer-portal-session")
async def create_customer_portal_session(_: User = Depends(current_admin_user)) -> dict:
try:
logger.info("test")
# Fetch tenant_id and the current tenant's information
tenant_id = current_tenant_id.get()
stripe_info = fetch_tenant_stripe_information(tenant_id)
Expand Down
32 changes: 21 additions & 11 deletions backend/ee/danswer/server/tenants/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,35 @@

from danswer.configs.app_configs import CONTROLPLANE_API_URL
from danswer.utils.logger import setup_logger
from ee.danswer.server.tenants.models import BillingInformation
from ee.danswer.server.tenants.access import generate_data_plane_token

logger = setup_logger()


def fetch_tenant_stripe_information(tenant_id: str) -> dict:
response = requests.get(
f"{CONTROLPLANE_API_URL}/tenant-stripe-information?tenant_id={tenant_id}",
)
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json", # Include if sending JSON data
}
url = f"{CONTROLPLANE_API_URL}/tenant-stripe-information"
params = {"tenant_id": tenant_id} # Use params for query parameters
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
return response.json()


def fetch_billing_information(tenant_id: str) -> BillingInformation:
def fetch_billing_information(tenant_id: str) -> dict:
logger.info("Fetching billing information")
response = requests.get(
f"{CONTROLPLANE_API_URL}/billing-information?tenant_id={tenant_id}",
)
logger.info("Billing information fetched", response.json())

token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROLPLANE_API_URL}/billing-information"
params = {"tenant_id": tenant_id}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
return response.json()
billing_info = response.json()
logger.info("Billing information fetched", billing_info)
return billing_info
21 changes: 12 additions & 9 deletions web/src/app/ee/admin/cloud-settings/BillingInformationPage.tsx
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
"use client";

import {
CreditCard,
ArrowUp,
ArrowUpLeft,
ArrowFatUp,
} from "@phosphor-icons/react";
import { CreditCard, ArrowFatUp } from "@phosphor-icons/react";
import { useState } from "react";
import { useRouter } from "next/navigation";
import { loadStripe, Stripe } from "@stripe/stripe-js";
import { loadStripe } from "@stripe/stripe-js";
import { usePopup } from "@/components/admin/connectors/Popup";
import { SettingsIcon } from "@/components/icons/icons";
import {
Expand All @@ -20,7 +15,6 @@ import {
import { useEffect } from "react";

export default function BillingInformationPage() {
const [seats, setSeats] = useState(1);
const router = useRouter();
const { popup, setPopup } = usePopup();
const stripePromise = loadStripe(
Expand All @@ -34,6 +28,16 @@ export default function BillingInformationPage() {
refreshBillingInformation,
} = useBillingInformation();

const [seats, setSeats] = useState<number | undefined>(
billingInformation?.seats
);

useEffect(() => {
if (billingInformation?.seats) {
setSeats(billingInformation.seats);
}
}, [billingInformation?.seats]);

if (error) {
console.error("Failed to fetch billing information:", error);
}
Expand Down Expand Up @@ -101,7 +105,6 @@ export default function BillingInformationPage() {
throw new Error("No portal URL returned from the server");
}

// Redirect to the Stripe Customer Portal
router.push(url);
} catch (error) {
console.error("Error creating customer portal session:", error);
Expand Down

0 comments on commit f6ad21e

Please sign in to comment.