Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proper tenant reset #3015

Merged
merged 3 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 35 additions & 23 deletions backend/danswer/db/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,18 @@ async def get_async_session_with_tenant(
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""Generate a database session bound to a connection with the appropriate tenant schema set."""
"""
Generate a database session bound to a connection with the appropriate tenant schema set.
This preserves the tenant ID across the session and reverts to the previous tenant ID
after the session is closed.
"""
engine = get_sqlalchemy_engine()

# Store the previous tenant ID
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()

if tenant_id is None:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
tenant_id = previous_tenant_id
else:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

Expand All @@ -335,30 +342,35 @@ def get_session_with_tenant(
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")

# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection

# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
finally:
cursor.close()
try:
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection

# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
yield session
cursor.execute(f'SET search_path = "{tenant_id}"')
finally:
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
cursor.close()

# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()

finally:
# Restore the previous tenant ID
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)


def set_search_path_on_checkout(
Expand Down
4 changes: 3 additions & 1 deletion backend/danswer/server/manage/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ def bulk_invite_users(
)

tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()

normalized_emails = []
try:
for email in emails:
Expand All @@ -206,13 +205,16 @@ def bulk_invite_users(
if MULTI_TENANT:
try:
add_users_to_tenant(normalized_emails, tenant_id)

except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
raise HTTPException(
status_code=400,
detail="User has already been invited to a Danswer organization",
)
raise
except Exception as e:
logger.error(f"Failed to add users to tenant {tenant_id}: {str(e)}")

initial_invited_users = get_invited_users()

Expand Down
Loading