Skip to content

Commit

Permalink
simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
pablonyx committed Nov 4, 2024
1 parent e79aa7d commit a23004a
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 43 deletions.
48 changes: 5 additions & 43 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from ee.danswer.server.tenants.provisioning import TenantProvisioningService
from ee.danswer.server.tenants.provisioning import get_or_create_tenant_id
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
Expand Down Expand Up @@ -239,29 +239,7 @@ async def create(
safe: bool = False,
request: Optional[Request] = None,
) -> User:
if MULTI_TENANT:
try:
tenant_id = get_tenant_id_for_email(user_create.email)

except exceptions.UserNotExists:
# If tenant does not exist and in Multi tenant mode, provision a new tenant
tenant_provisioning_service = TenantProvisioningService()
try:
tenant_id = await tenant_provisioning_service.provision_tenant(
user_create.email
)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(
status_code=500, detail="Failed to provision tenant."
)

if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
else:
tenant_id = POSTGRES_DEFAULT_SCHEMA
tenant_id = get_or_create_tenant_id(user_create.email)

async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
Expand Down Expand Up @@ -303,7 +281,8 @@ async def create(
else:
raise exceptions.UserAlreadyExists()

CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

return user

Expand All @@ -320,24 +299,7 @@ async def oauth_callback(
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
# Get tenant_id from mapping table
if MULTI_TENANT:
try:
tenant_id = get_tenant_id_for_email(account_email)
except exceptions.UserNotExists:
# Tenant does not exist; provision a new tenant
tenant_provisioning_service = TenantProvisioningService()
try:
tenant_id = await tenant_provisioning_service.provision_tenant(
account_email
)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(
status_code=500, detail="Failed to provision tenant."
)
else:
tenant_id = POSTGRES_DEFAULT_SCHEMA
tenant_id = get_or_create_tenant_id(account_email)

if not tenant_id:
raise HTTPException(status_code=401, detail="User not found")
Expand Down
27 changes: 27 additions & 0 deletions backend/ee/danswer/server/tenants/provisioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from alembic import command
from alembic.config import Config
from danswer.auth.users import exceptions
from danswer.auth.users import get_tenant_id_for_email
from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from danswer.configs.app_configs import EXPECTED_API_KEY
from danswer.db.engine import build_connection_string
Expand All @@ -32,9 +34,34 @@
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.enums import EmbeddingProvider


logger = logging.getLogger(__name__)


async def get_or_create_tenant_id(email: str) -> str:
"""Get existing tenant ID for an email or create a new tenant if none exists."""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA

try:
tenant_id = get_tenant_id_for_email(email)
except exceptions.UserNotExists:
# If tenant does not exist and in Multi tenant mode, provision a new tenant
tenant_provisioning_service = TenantProvisioningService()
try:
tenant_id = await tenant_provisioning_service.provision_tenant(email)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")

if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)

return tenant_id


class TenantProvisioningService:
async def provision_tenant(self, email: str) -> str:
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4()) # Generate new tenant ID
Expand Down

0 comments on commit a23004a

Please sign in to comment.