diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index afe369f7c37..11387298a3b 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -93,7 +93,7 @@ from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType from danswer.utils.variable_functionality import fetch_versioned_implementation -from ee.danswer.server.tenants.provisioning import TenantProvisioningService +from ee.danswer.server.tenants.provisioning import get_or_create_tenant_id from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR @@ -239,29 +239,7 @@ async def create( safe: bool = False, request: Optional[Request] = None, ) -> User: - if MULTI_TENANT: - try: - tenant_id = get_tenant_id_for_email(user_create.email) - - except exceptions.UserNotExists: - # If tenant does not exist and in Multi tenant mode, provision a new tenant - tenant_provisioning_service = TenantProvisioningService() - try: - tenant_id = await tenant_provisioning_service.provision_tenant( - user_create.email - ) - except Exception as e: - logger.error(f"Tenant provisioning failed: {e}") - raise HTTPException( - status_code=500, detail="Failed to provision tenant." - ) - - if not tenant_id: - raise HTTPException( - status_code=401, detail="User does not belong to an organization" - ) - else: - tenant_id = POSTGRES_DEFAULT_SCHEMA + tenant_id = get_or_create_tenant_id(user_create.email) async with get_async_session_with_tenant(tenant_id) as db_session: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) @@ -303,7 +281,8 @@ async def create( else: raise exceptions.UserAlreadyExists() - CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + finally: + CURRENT_TENANT_ID_CONTEXTVAR.reset(token) return user @@ -320,24 +299,7 @@ async def oauth_callback( associate_by_email: bool = False, is_verified_by_default: bool = False, ) -> models.UOAP: - # Get tenant_id from mapping table - if MULTI_TENANT: - try: - tenant_id = get_tenant_id_for_email(account_email) - except exceptions.UserNotExists: - # Tenant does not exist; provision a new tenant - tenant_provisioning_service = TenantProvisioningService() - try: - tenant_id = await tenant_provisioning_service.provision_tenant( - account_email - ) - except Exception as e: - logger.error(f"Tenant provisioning failed: {e}") - raise HTTPException( - status_code=500, detail="Failed to provision tenant." - ) - else: - tenant_id = POSTGRES_DEFAULT_SCHEMA + tenant_id = get_or_create_tenant_id(account_email) if not tenant_id: raise HTTPException(status_code=401, detail="User not found") diff --git a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py index 4f9189f476c..7b6790cc2f1 100644 --- a/backend/ee/danswer/server/tenants/provisioning.py +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -12,6 +12,8 @@ from alembic import command from alembic.config import Config +from danswer.auth.users import exceptions +from danswer.auth.users import get_tenant_id_for_email from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL from danswer.configs.app_configs import EXPECTED_API_KEY from danswer.db.engine import build_connection_string @@ -32,9 +34,34 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.enums import EmbeddingProvider + logger = logging.getLogger(__name__) +async def get_or_create_tenant_id(email: str) -> str: + """Get existing tenant ID for an email or create a new tenant if none exists.""" + if not MULTI_TENANT: + return POSTGRES_DEFAULT_SCHEMA + + try: + tenant_id = get_tenant_id_for_email(email) + except exceptions.UserNotExists: + # If tenant does not exist and in Multi tenant mode, provision a new tenant + tenant_provisioning_service = TenantProvisioningService() + try: + tenant_id = await tenant_provisioning_service.provision_tenant(email) + except Exception as e: + logger.error(f"Tenant provisioning failed: {e}") + raise HTTPException(status_code=500, detail="Failed to provision tenant.") + + if not tenant_id: + raise HTTPException( + status_code=401, detail="User does not belong to an organization" + ) + + return tenant_id + + class TenantProvisioningService: async def provision_tenant(self, email: str) -> str: tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4()) # Generate new tenant ID