From 0b08bf4e3f0a8133ca8fc21d734d4693a6fb58ae Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Thu, 31 Oct 2024 12:45:35 -0700 Subject: [PATCH] Proper tenant reset (#3015) * add proper tenant reset * clear comment * minor formatting --- backend/danswer/db/engine.py | 58 ++++++++++++++++---------- backend/danswer/server/manage/users.py | 4 +- 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index 7336493e0a0..a1fbbddc65f 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -322,11 +322,18 @@ async def get_async_session_with_tenant( def get_session_with_tenant( tenant_id: str | None = None, ) -> Generator[Session, None, None]: - """Generate a database session bound to a connection with the appropriate tenant schema set.""" + """ + Generate a database session bound to a connection with the appropriate tenant schema set. + This preserves the tenant ID across the session and reverts to the previous tenant ID + after the session is closed. + """ engine = get_sqlalchemy_engine() + # Store the previous tenant ID + previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() + if tenant_id is None: - tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() + tenant_id = previous_tenant_id else: CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) @@ -335,30 +342,35 @@ def get_session_with_tenant( if not is_valid_schema_name(tenant_id): raise HTTPException(status_code=400, detail="Invalid tenant ID") - # Establish a raw connection - with engine.connect() as connection: - # Access the raw DBAPI connection and set the search_path - dbapi_connection = connection.connection - - # Set the search_path outside of any transaction - cursor = dbapi_connection.cursor() - try: - cursor.execute(f'SET search_path = "{tenant_id}"') - finally: - cursor.close() + try: + # Establish a raw connection + with engine.connect() as connection: + # Access the raw DBAPI connection and set the search_path + dbapi_connection = connection.connection - # Bind the session to the connection - with Session(bind=connection, expire_on_commit=False) as session: + # Set the search_path outside of any transaction + cursor = dbapi_connection.cursor() try: - yield session + cursor.execute(f'SET search_path = "{tenant_id}"') finally: - # Reset search_path to default after the session is used - if MULTI_TENANT: - cursor = dbapi_connection.cursor() - try: - cursor.execute('SET search_path TO "$user", public') - finally: - cursor.close() + cursor.close() + + # Bind the session to the connection + with Session(bind=connection, expire_on_commit=False) as session: + try: + yield session + finally: + # Reset search_path to default after the session is used + if MULTI_TENANT: + cursor = dbapi_connection.cursor() + try: + cursor.execute('SET search_path TO "$user", public') + finally: + cursor.close() + + finally: + # Restore the previous tenant ID + CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id) def set_search_path_on_checkout( diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index bbaf08b7db9..9cb1cec1c1a 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -190,7 +190,6 @@ def bulk_invite_users( ) tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() - normalized_emails = [] try: for email in emails: @@ -206,6 +205,7 @@ def bulk_invite_users( if MULTI_TENANT: try: add_users_to_tenant(normalized_emails, tenant_id) + except IntegrityError as e: if isinstance(e.orig, UniqueViolation): raise HTTPException( @@ -213,6 +213,8 @@ def bulk_invite_users( detail="User has already been invited to a Danswer organization", ) raise + except Exception as e: + logger.error(f"Failed to add users to tenant {tenant_id}: {str(e)}") initial_invited_users = get_invited_users()