various multi tenant improvements (#2803)
* various multi tenant improvements

* nit

* ensure consistent db session operations

* minor robustification
pablonyx authored Oct 15, 2024
1 parent 0e6c2f0 commit bfe9639
Showing 15 changed files with 84 additions and 34 deletions.
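
The change repeated across the backend files below swaps raw Session(get_sqlalchemy_engine()) blocks for get_session_with_tenant(tenant_id). As a minimal sketch of what such a helper could look like, assuming a schema-per-tenant layout where the session's Postgres search_path is pointed at the tenant's schema (the actual implementation in danswer.db.engine may differ):

from contextlib import contextmanager
from typing import Generator

from sqlalchemy import text
from sqlalchemy.orm import Session

from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.db.engine import get_sqlalchemy_engine


@contextmanager
def get_session_with_tenant(
    tenant_id: str | None = None,
) -> Generator[Session, None, None]:
    """Yield a session scoped to one tenant's schema (illustrative sketch only)."""
    schema = tenant_id or POSTGRES_DEFAULT_SCHEMA
    with Session(get_sqlalchemy_engine()) as session:
        # Point unqualified table references at the tenant's schema for the
        # lifetime of this session's connection.
        session.execute(text(f'SET search_path TO "{schema}"'))
        yield session

Call sites then pass whichever tenant_id the task or request was scheduled with; None falls back to the default schema.
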
@@ -12,7 +12,7 @@
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_pool import get_redis_client
@@ -36,7 +36,7 @@ def check_for_connector_deletion_task(tenant_id: str | None) -> None:
if not lock_beat.acquire(blocking=False):
return

with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
try_generate_document_cc_pair_cleanup_tasks(
6 changes: 3 additions & 3 deletions backend/danswer/background/celery/tasks/periodic/tasks.py
@@ -14,7 +14,7 @@
from danswer.background.celery.celery_app import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_sqlalchemy_engine # type: ignore
from danswer.db.engine import get_session_with_tenant


@shared_task(
@@ -23,7 +23,7 @@
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any) -> int:
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
"""Runs periodically to clean up the kombu_message table"""

# we will select messages older than this amount to clean up
@@ -35,7 +35,7 @@ def kombu_message_cleanup_task(self: Any) -> int:
ctx["deleted"] = 0
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
# Exit the task if we can't take the advisory lock
result = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),
9 changes: 5 additions & 4 deletions backend/danswer/background/celery/tasks/vespa/tasks.py
@@ -39,7 +39,6 @@
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import DocumentSet
from danswer.db.models import UserGroup
@@ -341,7 +340,9 @@ def monitor_document_set_taskset(
r.delete(rds.fence_key)


def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
def monitor_connector_deletion_taskset(
key_bytes: bytes, r: Redis, tenant_id: str | None
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id is None:
@@ -367,7 +368,7 @@ def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
if count > 0:
return

with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
task_logger.warning(
@@ -529,7 +530,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> None:

lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
monitor_connector_deletion_taskset(key_bytes, r)
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)

with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
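
monitor_vespa_sync, check_for_connector_deletion_task, and the other beat tasks above now all accept a tenant_id, which suggests the scheduler enqueues one run per tenant. A hypothetical fan-out is sketched below; get_all_tenant_ids is an assumed stand-in, the celery_app export is assumed from the module that provides task_logger, and the real scheduling code may be structured quite differently:

from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.tasks.vespa.tasks import monitor_vespa_sync


def get_all_tenant_ids() -> list[str]:
    # Stand-in for however tenants are enumerated (e.g. a control-plane table).
    return ["tenant_abc123", "tenant_def456"]


@celery_app.task
def kickoff_vespa_sync_for_all_tenants() -> None:
    # One enqueue per tenant so each run opens its sessions against the
    # matching schema via get_session_with_tenant(tenant_id).
    for tenant_id in get_all_tenant_ids():
        monitor_vespa_sync.apply_async(kwargs={"tenant_id": tenant_id})
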
1 change: 1 addition & 0 deletions backend/danswer/background/indexing/run_indexing.py
@@ -65,6 +65,7 @@ def _get_connector_runner(
input_type=task,
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
credential=attempt.connector_credential_pair.credential,
tenant_id=tenant_id,
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
3 changes: 3 additions & 0 deletions backend/danswer/configs/constants.py
@@ -118,6 +118,9 @@ class DocumentSource(str, Enum):
NOT_APPLICABLE = "not_applicable"


DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]


class NotificationType(str, Enum):
REINDEX = "reindex"

6 changes: 6 additions & 0 deletions backend/danswer/connectors/factory.py
@@ -4,6 +4,7 @@
from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.configs.constants import DocumentSourceRequiringTenantContext
from danswer.connectors.asana.connector import AsanaConnector
from danswer.connectors.axero.connector import AxeroConnector
from danswer.connectors.blob.connector import BlobStorageConnector
@@ -134,8 +135,13 @@ def instantiate_connector(
input_type: InputType,
connector_specific_config: dict[str, Any],
credential: Credential,
tenant_id: str | None = None,
) -> BaseConnector:
connector_class = identify_connector_class(source, input_type)

if source in DocumentSourceRequiringTenantContext:
connector_specific_config["tenant_id"] = tenant_id

connector = connector_class(**connector_specific_config)
new_credentials = connector.load_credentials(credential.credential_json)

7 changes: 5 additions & 2 deletions backend/danswer/connectors/file/connector.py
@@ -10,13 +10,14 @@

from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.file_processing.extract_file_text import check_file_ext_is_valid
from danswer.file_processing.extract_file_text import detect_encoding
from danswer.file_processing.extract_file_text import extract_file_text
@@ -159,10 +160,12 @@ class LocalFileConnector(LoadConnector):
def __init__(
self,
file_locations: list[Path | str],
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.file_locations = [Path(file_location) for file_location in file_locations]
self.batch_size = batch_size
self.tenant_id = tenant_id
self.pdf_pass: str | None = None

def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -171,7 +174,7 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None

def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(self.tenant_id) as db_session:
for file_path in self.file_locations:
current_datetime = datetime.now(timezone.utc)
files = _read_files_and_metadata(
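
With the change above, LocalFileConnector keeps the tenant it was created for and opens its sessions through get_session_with_tenant, so file contents are read from that tenant's file store. A usage sketch, with a hypothetical file path and tenant name:

from danswer.connectors.file.connector import LocalFileConnector

connector = LocalFileConnector(
    file_locations=["uploads/employee_handbook.pdf"],  # hypothetical path
    tenant_id="tenant_abc123",  # hypothetical tenant schema name
)
connector.load_credentials({})
for document_batch in connector.load_from_state():
    # Each yielded batch is a list of Document objects read under the tenant's schema.
    print(f"loaded batch of {len(document_batch)} documents")
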
5 changes: 1 addition & 4 deletions backend/danswer/db/index_attempt.py
@@ -435,14 +435,13 @@ def cancel_indexing_attempts_for_ccpair(

db_session.execute(stmt)

db_session.commit()


def cancel_indexing_attempts_past_model(
db_session: Session,
) -> None:
"""Stops all indexing attempts that are in progress or not started for
any embedding model that not present/future"""

db_session.execute(
update(IndexAttempt)
.where(
@@ -455,8 +454,6 @@ def cancel_indexing_attempts_past_model(
.values(status=IndexingStatus.FAILED)
)

db_session.commit()


def count_unique_cc_pairs_with_successful_index_attempts(
search_settings_id: int | None,
4 changes: 3 additions & 1 deletion backend/danswer/server/documents/cc_pair.py
@@ -154,6 +154,7 @@ def update_cc_pair_status(
user=user,
get_editable=True,
)

if not cc_pair:
raise HTTPException(
status_code=400,
@@ -163,7 +164,6 @@
if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)

# Just for good measure
cancel_indexing_attempts_past_model(db_session)

update_connector_credential_pair_from_id(
@@ -172,6 +172,8 @@
status=status_update_request.status,
)

db_session.commit()


@router.put("/admin/cc-pair/{cc_pair_id}/name")
def update_cc_pair_name(
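
The index_attempt.py and cc_pair.py changes above implement the "ensure consistent db session operations" item: the cancel_* helpers no longer commit on their own, and the route handler commits once after every mutation has been applied to the shared session. A condensed sketch of the resulting pattern; the module path for update_connector_credential_pair_from_id and the keyword arguments hidden by the hunk are assumed:

from sqlalchemy.orm import Session

from danswer.db.connector_credential_pair import (
    update_connector_credential_pair_from_id,  # assumed to live in this module
)
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import (
    cancel_indexing_attempts_for_ccpair,
    cancel_indexing_attempts_past_model,
)


def pause_cc_pair(cc_pair_id: int, db_session: Session) -> None:
    # Helpers issue their UPDATEs on the shared session but do not commit.
    cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
    cancel_indexing_attempts_past_model(db_session)
    update_connector_credential_pair_from_id(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
        status=ConnectorCredentialPairStatus.PAUSED,
    )
    # One commit at the end keeps the whole status change atomic.
    db_session.commit()
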
1 change: 1 addition & 0 deletions backend/danswer/server/manage/search_settings.py
@@ -115,6 +115,7 @@ def set_new_search_settings(
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(cc_pair, db_session=db_session)

db_session.commit()
return IdReturn(id=new_search_settings.id)


12 changes: 7 additions & 5 deletions backend/danswer/server/query_and_chat/token_limit.py
@@ -13,13 +13,15 @@

from danswer.auth.users import current_user
from danswer.db.engine import get_session_context_manager
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
from danswer.db.models import TokenRateLimit
from danswer.db.models import User
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
from shared_configs.configs import current_tenant_id


logger = setup_logger()
@@ -39,20 +41,20 @@ def check_token_rate_limits(
versioned_rate_limit_strategy = fetch_versioned_implementation(
"danswer.server.query_and_chat.token_limit", "_check_token_rate_limits"
)
return versioned_rate_limit_strategy(user)
return versioned_rate_limit_strategy(user, current_tenant_id.get())


def _check_token_rate_limits(_: User | None) -> None:
_user_is_rate_limited_by_global()
def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
_user_is_rate_limited_by_global(tenant_id)


"""
Global rate limits
"""


def _user_is_rate_limited_by_global() -> None:
with get_session_context_manager() as db_session:
def _user_is_rate_limited_by_global(tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
global_rate_limits = fetch_all_global_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)
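
check_token_rate_limits now reads the tenant from shared_configs.configs.current_tenant_id via .get(), so current_tenant_id is presumably a contextvars.ContextVar that request middleware populates before the dependency runs. A sketch of that assumption follows; the default schema name and the middleware step are illustrative, not the project's actual code:

from contextvars import ContextVar

# Assumed shape of shared_configs.configs.current_tenant_id; the real default
# may be POSTGRES_DEFAULT_SCHEMA rather than the literal "public".
current_tenant_id: ContextVar[str] = ContextVar(
    "current_tenant_id", default="public"
)


def resolve_tenant_for_request(tenant_from_auth: str | None) -> None:
    # Hypothetical middleware step: stash the caller's tenant so downstream
    # code such as _check_token_rate_limits(user, current_tenant_id.get())
    # opens its session against the right schema.
    token = current_tenant_id.set(tenant_from_auth or "public")
    try:
        pass  # handle the request here
    finally:
        current_tenant_id.reset(token)
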
22 changes: 11 additions & 11 deletions backend/ee/danswer/server/query_and_chat/token_limit.py
@@ -12,7 +12,7 @@
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.engine import get_session_context_manager
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
from danswer.db.models import TokenRateLimit
@@ -28,21 +28,21 @@
from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits


def _check_token_rate_limits(user: User | None) -> None:
def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:
if user is None:
# Unauthenticated users are only rate limited by global settings
_user_is_rate_limited_by_global()
_user_is_rate_limited_by_global(tenant_id)

elif is_api_key_email_address(user.email):
# API keys are only rate limited by global settings
_user_is_rate_limited_by_global()
_user_is_rate_limited_by_global(tenant_id)

else:
run_functions_tuples_in_parallel(
[
(_user_is_rate_limited, (user.id,)),
(_user_is_rate_limited_by_group, (user.id,)),
(_user_is_rate_limited_by_global, ()),
(_user_is_rate_limited, (user.id, tenant_id)),
(_user_is_rate_limited_by_group, (user.id, tenant_id)),
(_user_is_rate_limited_by_global, (tenant_id,)),
]
)

@@ -52,8 +52,8 @@ def _check_token_rate_limits(user: User | None) -> None:
"""


def _user_is_rate_limited(user_id: UUID) -> None:
with get_session_context_manager() as db_session:
def _user_is_rate_limited(user_id: UUID, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
user_rate_limits = fetch_all_user_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)
@@ -93,8 +93,8 @@ def _fetch_user_usage(
"""


def _user_is_rate_limited_by_group(user_id: UUID) -> None:
with get_session_context_manager() as db_session:
def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)

if group_rate_limits:
2 changes: 2 additions & 0 deletions backend/scripts/force_delete_connector_by_id.py
@@ -206,6 +206,8 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
logger.notice(f"Deleting file {file_name}")
file_store.delete_file(file_name)

db_session.commit()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Delete a connector by its ID")
35 changes: 33 additions & 2 deletions web/src/app/auth/logout/route.ts
@@ -1,3 +1,4 @@
import { CLOUD_ENABLED } from "@/lib/constants";
import { getAuthTypeMetadataSS, logoutSS } from "@/lib/userSS";
import { NextRequest } from "next/server";

@@ -6,8 +7,38 @@ export const POST = async (request: NextRequest) => {
// Needed since env variables don't work well on the client-side
const authTypeMetadata = await getAuthTypeMetadataSS();
const response = await logoutSS(authTypeMetadata.authType, request.headers);
if (!response || response.ok) {

if (response && !response.ok) {
return new Response(response.body, { status: response?.status });
}

// Delete cookies only if cloud is enabled (jwt auth)
if (CLOUD_ENABLED) {
const cookiesToDelete = ["fastapiusersauth", "tenant_details"];
const cookieOptions = {
path: "/",
secure: process.env.NODE_ENV === "production",
httpOnly: true,
sameSite: "lax" as const,
};

// Logout successful, delete cookies
const headers = new Headers();

cookiesToDelete.forEach((cookieName) => {
headers.append(
"Set-Cookie",
`${cookieName}=; Max-Age=0; ${Object.entries(cookieOptions)
.map(([key, value]) => `${key}=${value}`)
.join("; ")}`
);
});

return new Response(null, {
status: 204,
headers: headers,
});
} else {
return new Response(null, { status: 204 });
}
return new Response(response.body, { status: response?.status });
};
1 change: 1 addition & 0 deletions web/src/lib/constants.ts
@@ -58,6 +58,7 @@ export const DISABLE_LLM_DOC_RELEVANCE =

export const CLOUD_ENABLED =
process.env.NEXT_PUBLIC_CLOUD_ENABLED?.toLowerCase() === "true";

export const REGISTRATION_URL =
process.env.INTERNAL_URL || "http://127.0.0.1:3001";

