From a5217ac8ce0e0e4914afb8b4398884ed92ffaa86 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Mon, 4 Nov 2024 16:16:30 -0800 Subject: [PATCH] update provisioning --- backend/danswer/auth/users.py | 7 +- backend/danswer/server/manage/users.py | 2 +- backend/ee/danswer/server/tenants/api.py | 54 ---- .../ee/danswer/server/tenants/provisioning.py | 254 ++++++------------ .../server/tenants/schema_management.py | 76 ++++++ .../ee/danswer/server/tenants/user_mapping.py | 50 ++++ 6 files changed, 206 insertions(+), 237 deletions(-) create mode 100644 backend/ee/danswer/server/tenants/schema_management.py create mode 100644 backend/ee/danswer/server/tenants/user_mapping.py diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 11387298a3b..6cd15362239 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -98,7 +98,6 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR - logger = setup_logger() @@ -239,7 +238,7 @@ async def create( safe: bool = False, request: Optional[Request] = None, ) -> User: - tenant_id = get_or_create_tenant_id(user_create.email) + tenant_id = await get_or_create_tenant_id(user_create.email) async with get_async_session_with_tenant(tenant_id) as db_session: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) @@ -262,7 +261,7 @@ async def create( user_create.role = UserRole.BASIC try: - user = await super().create(user_create, safe=safe, request=request) + user = await super().create(user_create, safe=safe, request=request) # type: ignore except exceptions.UserAlreadyExists: user = await self.get_by_email(user_create.email) # Handle case where user has used product outside of web and is now creating an account through web @@ -299,7 +298,7 @@ async def oauth_callback( associate_by_email: bool = False, is_verified_by_default: bool = False, ) -> models.UOAP: - tenant_id = get_or_create_tenant_id(account_email) + tenant_id = await get_or_create_tenant_id(account_email) if not tenant_id: raise HTTPException(status_code=401, detail="User not found") diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index 59c4de89a71..701161ff3b7 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -66,7 +66,7 @@ from ee.danswer.db.user_group import remove_curator_status__no_commit from ee.danswer.server.tenants.billing import register_tenant_users from ee.danswer.server.tenants.provisioning import add_users_to_tenant -from ee.danswer.server.tenants.provisioning import remove_users_from_tenant +from ee.danswer.server.tenants.user_mapping import remove_users_from_tenant from shared_configs.configs import MULTI_TENANT logger = setup_logger() diff --git a/backend/ee/danswer/server/tenants/api.py b/backend/ee/danswer/server/tenants/api.py index 8e79c0b37b1..e6d0e048c83 100644 --- a/backend/ee/danswer/server/tenants/api.py +++ b/backend/ee/danswer/server/tenants/api.py @@ -15,7 +15,6 @@ from danswer.db.users import get_user_by_email from danswer.server.settings.store import load_settings from danswer.server.settings.store import store_settings -from danswer.setup import setup_danswer from danswer.utils.logger import setup_logger from ee.danswer.auth.users import current_cloud_superuser from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY @@ -23,15 +22,8 @@ from ee.danswer.server.tenants.billing import fetch_billing_information from ee.danswer.server.tenants.billing import fetch_tenant_stripe_information from ee.danswer.server.tenants.models import BillingInformation -from ee.danswer.server.tenants.models import CreateTenantRequest from ee.danswer.server.tenants.models import ImpersonateRequest from ee.danswer.server.tenants.models import ProductGatingRequest -from ee.danswer.server.tenants.provisioning import add_users_to_tenant -from ee.danswer.server.tenants.provisioning import configure_default_api_keys -from ee.danswer.server.tenants.provisioning import ensure_schema_exists -from ee.danswer.server.tenants.provisioning import run_alembic_migrations -from ee.danswer.server.tenants.provisioning import user_owns_a_tenant -from shared_configs.configs import MULTI_TENANT from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR stripe.api_key = STRIPE_SECRET_KEY @@ -40,52 +32,6 @@ router = APIRouter(prefix="/tenants") -@router.post("/create") -def create_tenant( - create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep) -) -> dict[str, str]: - if not MULTI_TENANT: - raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled") - - tenant_id = create_tenant_request.tenant_id - email = create_tenant_request.initial_admin_email - token = None - - if user_owns_a_tenant(email): - raise HTTPException( - status_code=409, detail="User already belongs to an organization" - ) - - try: - if not ensure_schema_exists(tenant_id): - logger.info(f"Created schema for tenant {tenant_id}") - else: - logger.info(f"Schema already exists for tenant {tenant_id}") - - token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) - run_alembic_migrations(tenant_id) - - with get_session_with_tenant(tenant_id) as db_session: - setup_danswer(db_session, tenant_id) - - configure_default_api_keys(db_session) - - add_users_to_tenant([email], tenant_id) - - return { - "status": "success", - "message": f"Tenant {tenant_id} created successfully", - } - except Exception as e: - logger.exception(f"Failed to create tenant {tenant_id}: {str(e)}") - raise HTTPException( - status_code=500, detail=f"Failed to create tenant: {str(e)}" - ) - finally: - if token is not None: - CURRENT_TENANT_ID_CONTEXTVAR.reset(token) - - @router.post("/product-gating") def gate_product( product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep) diff --git a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py index 7b6790cc2f1..0cb3821bc24 100644 --- a/backend/ee/danswer/server/tenants/provisioning.py +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -1,22 +1,15 @@ import asyncio import logging -import os import uuid -from types import SimpleNamespace import aiohttp # Async HTTP client from fastapi import HTTPException -from sqlalchemy import text from sqlalchemy.orm import Session -from sqlalchemy.schema import CreateSchema -from alembic import command -from alembic.config import Config from danswer.auth.users import exceptions from danswer.auth.users import get_tenant_id_for_email from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL from danswer.configs.app_configs import EXPECTED_API_KEY -from danswer.db.engine import build_connection_string from danswer.db.engine import get_session_with_tenant from danswer.db.engine import get_sqlalchemy_engine from danswer.db.llm import upsert_cloud_embedding_provider @@ -28,13 +21,17 @@ from ee.danswer.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY from ee.danswer.configs.app_configs import OPENAI_DEFAULT_API_KEY +from ee.danswer.server.tenants.schema_management import create_schema_if_not_exists +from ee.danswer.server.tenants.schema_management import drop_schema +from ee.danswer.server.tenants.schema_management import run_alembic_migrations +from ee.danswer.server.tenants.user_mapping import add_users_to_tenant +from ee.danswer.server.tenants.user_mapping import user_owns_a_tenant from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.configs import TENANT_ID_PREFIX from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.enums import EmbeddingProvider - logger = logging.getLogger(__name__) @@ -47,9 +44,8 @@ async def get_or_create_tenant_id(email: str) -> str: tenant_id = get_tenant_id_for_email(email) except exceptions.UserNotExists: # If tenant does not exist and in Multi tenant mode, provision a new tenant - tenant_provisioning_service = TenantProvisioningService() try: - tenant_id = await tenant_provisioning_service.provision_tenant(email) + tenant_id = await create_tenant(email) except Exception as e: logger.error(f"Tenant provisioning failed: {e}") raise HTTPException(status_code=500, detail="Failed to provision tenant.") @@ -62,170 +58,94 @@ async def get_or_create_tenant_id(email: str) -> str: return tenant_id -class TenantProvisioningService: - async def provision_tenant(self, email: str) -> str: - tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4()) # Generate new tenant ID - +async def create_tenant(email: str) -> str: + tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4()) + try: # Provision tenant on data plane - await self._provision_on_data_plane(tenant_id, email) - + await provision_tenant(tenant_id, email) # Notify control plane - await self._notify_control_plane(tenant_id, email) - - return tenant_id - - async def _provision_on_data_plane(self, tenant_id: str, email: str) -> None: - if not MULTI_TENANT: - raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled") - - if user_owns_a_tenant(email): - raise HTTPException( - status_code=409, detail="User already belongs to an organization" - ) - - logger.info(f"Provisioning tenant: {tenant_id}") - token = None - - try: - if not ensure_schema_exists(tenant_id): - logger.info(f"Created schema for tenant {tenant_id}") - else: - logger.info(f"Schema already exists for tenant {tenant_id}") + await notify_control_plane(tenant_id, email) + except Exception as e: + logger.error(f"Tenant provisioning failed: {e}") + await rollback_tenant_provisioning(tenant_id) + raise HTTPException(status_code=500, detail="Failed to provision tenant.") + return tenant_id - token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) - # Await the Alembic migrations - await asyncio.to_thread(run_alembic_migrations, tenant_id) +async def provision_tenant(tenant_id: str, email: str) -> None: + if not MULTI_TENANT: + raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled") - with get_session_with_tenant(tenant_id) as db_session: - setup_danswer(db_session, tenant_id) - configure_default_api_keys(db_session) + if user_owns_a_tenant(email): + raise HTTPException( + status_code=409, detail="User already belongs to an organization" + ) - add_users_to_tenant([email], tenant_id) + logger.info(f"Provisioning tenant: {tenant_id}") + token = None - except Exception as e: - logger.exception(f"Failed to create tenant {tenant_id}") - raise HTTPException( - status_code=500, detail=f"Failed to create tenant: {str(e)}" - ) - finally: - if token is not None: - CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + try: + if not create_schema_if_not_exists(tenant_id): + logger.info(f"Created schema for tenant {tenant_id}") + else: + logger.info(f"Schema already exists for tenant {tenant_id}") - async def _notify_control_plane(self, tenant_id: str, email: str) -> None: - headers = { - "Authorization": f"Bearer {EXPECTED_API_KEY}", # Replace with your control plane API key - "Content-Type": "application/json", - } - payload = {"tenant_id": tenant_id, "email": email} + token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) - async with aiohttp.ClientSession() as session: - async with session.post( - f"{CONTROL_PLANE_API_BASE_URL}/tenants/create", # Replace with your control plane URL - headers=headers, - json=payload, - ) as response: - if response.status != 200: - error_text = await response.text() - logger.error(f"Control plane tenant creation failed: {error_text}") - raise Exception( - f"Failed to create tenant on control plane: {error_text}" - ) + # Await the Alembic migrations + await asyncio.to_thread(run_alembic_migrations, tenant_id) - async def rollback_tenant_provisioning(self, tenant_id: str) -> None: - # Logic to rollback tenant provisioning on data plane - logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}") - try: - # Drop the tenant's schema to rollback provisioning - drop_schema(tenant_id) - # Remove tenant mapping - with Session(get_sqlalchemy_engine()) as db_session: - db_session.query(UserTenantMapping).filter( - UserTenantMapping.tenant_id == tenant_id - ).delete() - db_session.commit() - except Exception as e: - logger.error(f"Failed to rollback tenant provisioning: {e}") + with get_session_with_tenant(tenant_id) as db_session: + setup_danswer(db_session, tenant_id) + configure_default_api_keys(db_session) + add_users_to_tenant([email], tenant_id) -# For now, we're implementing a primitive mapping between users and tenants. -# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership). -def user_owns_a_tenant(email: str) -> bool: - with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: - result = ( - db_session.query(UserTenantMapping) - .filter(UserTenantMapping.email == email) - .first() + except Exception as e: + logger.exception(f"Failed to create tenant {tenant_id}") + raise HTTPException( + status_code=500, detail=f"Failed to create tenant: {str(e)}" ) - return result is not None - - -def add_users_to_tenant(emails: list[str], tenant_id: str) -> None: - with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: - try: - for email in emails: - db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id)) - except Exception: - logger.exception(f"Failed to add users to tenant {tenant_id}") - db_session.commit() - - -def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None: - with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: - try: - mappings_to_delete = ( - db_session.query(UserTenantMapping) - .filter( - UserTenantMapping.email.in_(emails), - UserTenantMapping.tenant_id == tenant_id, + finally: + if token is not None: + CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + + +async def notify_control_plane(tenant_id: str, email: str) -> None: + headers = { + "Authorization": f"Bearer {EXPECTED_API_KEY}", + "Content-Type": "application/json", + } + payload = {"tenant_id": tenant_id, "email": email} + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{CONTROL_PLANE_API_BASE_URL}/tenants/create", + headers=headers, + json=payload, + ) as response: + if response.status != 200: + error_text = await response.text() + logger.error(f"Control plane tenant creation failed: {error_text}") + raise Exception( + f"Failed to create tenant on control plane: {error_text}" ) - .all() - ) - for mapping in mappings_to_delete: - db_session.delete(mapping) - - db_session.commit() - except Exception as e: - logger.exception( - f"Failed to remove users from tenant {tenant_id}: {str(e)}" - ) - db_session.rollback() - - -def run_alembic_migrations(schema_name: str) -> None: - logger.info(f"Starting Alembic migrations for schema: {schema_name}") +async def rollback_tenant_provisioning(tenant_id: str) -> None: + # Logic to rollback tenant provisioning on data plane + logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}") try: - current_dir = os.path.dirname(os.path.abspath(__file__)) - root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", "..")) - alembic_ini_path = os.path.join(root_dir, "alembic.ini") - - # Configure Alembic - alembic_cfg = Config(alembic_ini_path) - alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string()) - alembic_cfg.set_main_option( - "script_location", os.path.join(root_dir, "alembic") - ) - - # Ensure that logging isn't broken - alembic_cfg.attributes["configure_logger"] = False - - # Mimic command-line options by adding 'cmd_opts' to the config - alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore - alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore - - # Run migrations programmatically - command.upgrade(alembic_cfg, "head") - - # Run migrations programmatically - logger.info( - f"Alembic migrations completed successfully for schema: {schema_name}" - ) - + # Drop the tenant's schema to rollback provisioning + drop_schema(tenant_id) + # Remove tenant mapping + with Session(get_sqlalchemy_engine()) as db_session: + db_session.query(UserTenantMapping).filter( + UserTenantMapping.tenant_id == tenant_id + ).delete() + db_session.commit() except Exception as e: - logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}") - raise + logger.error(f"Failed to rollback tenant provisioning: {e}") def configure_default_api_keys(db_session: Session) -> None: @@ -249,25 +169,3 @@ def configure_default_api_keys(db_session: Session) -> None: api_key=COHERE_DEFAULT_API_KEY, ) upsert_cloud_embedding_provider(db_session, cloud_embedding_provider) - - -def ensure_schema_exists(tenant_id: str) -> bool: - with Session(get_sqlalchemy_engine()) as db_session: - with db_session.begin(): - result = db_session.execute( - text( - "SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name" - ), - {"schema_name": tenant_id}, - ) - schema_exists = result.scalar() is not None - if not schema_exists: - stmt = CreateSchema(tenant_id) - db_session.execute(stmt) - return True - return False - - -def drop_schema(tenant_id: str) -> None: - with get_sqlalchemy_engine().connect() as connection: - connection.execute(text(f"DROP SCHEMA IF EXISTS {tenant_id} CASCADE")) diff --git a/backend/ee/danswer/server/tenants/schema_management.py b/backend/ee/danswer/server/tenants/schema_management.py new file mode 100644 index 00000000000..9be4e79f984 --- /dev/null +++ b/backend/ee/danswer/server/tenants/schema_management.py @@ -0,0 +1,76 @@ +import logging +import os +from types import SimpleNamespace + +from sqlalchemy import text +from sqlalchemy.orm import Session +from sqlalchemy.schema import CreateSchema + +from alembic import command +from alembic.config import Config +from danswer.db.engine import build_connection_string +from danswer.db.engine import get_sqlalchemy_engine + +logger = logging.getLogger(__name__) + + +def run_alembic_migrations(schema_name: str) -> None: + logger.info(f"Starting Alembic migrations for schema: {schema_name}") + + try: + current_dir = os.path.dirname(os.path.abspath(__file__)) + root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", "..")) + alembic_ini_path = os.path.join(root_dir, "alembic.ini") + + # Configure Alembic + alembic_cfg = Config(alembic_ini_path) + alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string()) + alembic_cfg.set_main_option( + "script_location", os.path.join(root_dir, "alembic") + ) + + # Ensure that logging isn't broken + alembic_cfg.attributes["configure_logger"] = False + + # Mimic command-line options by adding 'cmd_opts' to the config + alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore + alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore + + # Run migrations programmatically + command.upgrade(alembic_cfg, "head") + + # Run migrations programmatically + logger.info( + f"Alembic migrations completed successfully for schema: {schema_name}" + ) + + except Exception as e: + logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}") + raise + + +def create_schema_if_not_exists(tenant_id: str) -> bool: + with Session(get_sqlalchemy_engine()) as db_session: + with db_session.begin(): + result = db_session.execute( + text( + "SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name" + ), + {"schema_name": tenant_id}, + ) + schema_exists = result.scalar() is not None + if not schema_exists: + stmt = CreateSchema(tenant_id) + db_session.execute(stmt) + return True + return False + + +def drop_schema(tenant_id: str) -> None: + if not tenant_id.isidentifier(): + raise ValueError("Invalid tenant_id.") + with get_sqlalchemy_engine().connect() as connection: + connection.execute( + text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"), + {"schema_name": tenant_id}, + ) diff --git a/backend/ee/danswer/server/tenants/user_mapping.py b/backend/ee/danswer/server/tenants/user_mapping.py new file mode 100644 index 00000000000..3a3e9befc59 --- /dev/null +++ b/backend/ee/danswer/server/tenants/user_mapping.py @@ -0,0 +1,50 @@ +import logging + +from danswer.db.engine import get_session_with_tenant +from danswer.db.models import UserTenantMapping +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA + +logger = logging.getLogger(__name__) + + +def user_owns_a_tenant(email: str) -> bool: + with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: + result = ( + db_session.query(UserTenantMapping) + .filter(UserTenantMapping.email == email) + .first() + ) + return result is not None + + +def add_users_to_tenant(emails: list[str], tenant_id: str) -> None: + with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: + try: + for email in emails: + db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id)) + except Exception: + logger.exception(f"Failed to add users to tenant {tenant_id}") + db_session.commit() + + +def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None: + with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: + try: + mappings_to_delete = ( + db_session.query(UserTenantMapping) + .filter( + UserTenantMapping.email.in_(emails), + UserTenantMapping.tenant_id == tenant_id, + ) + .all() + ) + + for mapping in mappings_to_delete: + db_session.delete(mapping) + + db_session.commit() + except Exception as e: + logger.exception( + f"Failed to remove users from tenant {tenant_id}: {str(e)}" + ) + db_session.rollback()