Skip to content

Commit

Permalink
rename classes and ignore deprecation warnings we mostly don't have c… (
Browse files Browse the repository at this point in the history
#2546)

* rename classes and ignore deprecation warnings we mostly don't have control over

* copy pytest.ini

* ignore CryptographyDeprecationWarning

* fully qualify the warning
  • Loading branch information
rkuo-danswer authored Sep 24, 2024
1 parent cb75449 commit c8d1392
Show file tree
Hide file tree
Showing 30 changed files with 242 additions and 232 deletions.
6 changes: 5 additions & 1 deletion backend/pytest.ini
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
[pytest]
pythonpath = .
markers =
slow: marks tests as slow
slow: marks tests as slow
filterwarnings =
ignore::DeprecationWarning
ignore::cryptography.utils.CryptographyDeprecationWarning

1 change: 1 addition & 0 deletions backend/tests/integration/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ COPY ./danswer /app/danswer
COPY ./shared_configs /app/shared_configs
COPY ./alembic /app/alembic
COPY ./alembic.ini /app/alembic.ini
COPY ./pytest.ini /app/pytest.ini
COPY supervisord.conf /usr/etc/supervisord.conf

# Integration test stuff
Expand Down
14 changes: 7 additions & 7 deletions backend/tests/integration/common_utils/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import TestLLMProvider
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser


class LLMProviderManager:
Expand All @@ -21,8 +21,8 @@ def create(
api_version: str | None = None,
groups: list[int] | None = None,
is_public: bool | None = None,
user_performing_action: TestUser | None = None,
) -> TestLLMProvider:
user_performing_action: DATestUser | None = None,
) -> DATestLLMProvider:
print("Seeding LLM Providers...")

llm_provider = LLMProviderUpsertRequest(
Expand All @@ -49,7 +49,7 @@ def create(
)
llm_response.raise_for_status()
response_data = llm_response.json()
result_llm = TestLLMProvider(
result_llm = DATestLLMProvider(
id=response_data["id"],
name=response_data["name"],
provider=response_data["provider"],
Expand All @@ -73,8 +73,8 @@ def create(

@staticmethod
def delete(
llm_provider: TestLLMProvider,
user_performing_action: TestUser | None = None,
llm_provider: DATestLLMProvider,
user_performing_action: DATestUser | None = None,
) -> bool:
if not llm_provider.id:
raise ValueError("LLM Provider ID is required to delete a provider")
Expand Down
24 changes: 12 additions & 12 deletions backend/tests/integration/common_utils/managers/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
from ee.danswer.server.api_key.models import APIKeyArgs
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import TestAPIKey
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestUser


class APIKeyManager:
@staticmethod
def create(
name: str | None = None,
api_key_role: UserRole = UserRole.ADMIN,
user_performing_action: TestUser | None = None,
) -> TestAPIKey:
user_performing_action: DATestUser | None = None,
) -> DATestAPIKey:
name = f"{name}-api-key" if name else f"test-api-key-{uuid4()}"
api_key_request = APIKeyArgs(
name=name,
Expand All @@ -31,7 +31,7 @@ def create(
)
api_key_response.raise_for_status()
api_key = api_key_response.json()
result_api_key = TestAPIKey(
result_api_key = DATestAPIKey(
api_key_id=api_key["api_key_id"],
api_key_display=api_key["api_key_display"],
api_key=api_key["api_key"],
Expand All @@ -45,8 +45,8 @@ def create(

@staticmethod
def delete(
api_key: TestAPIKey,
user_performing_action: TestUser | None = None,
api_key: DATestAPIKey,
user_performing_action: DATestUser | None = None,
) -> None:
api_key_response = requests.delete(
f"{API_SERVER_URL}/admin/api-key/{api_key.api_key_id}",
Expand All @@ -58,22 +58,22 @@ def delete(

@staticmethod
def get_all(
user_performing_action: TestUser | None = None,
) -> list[TestAPIKey]:
user_performing_action: DATestUser | None = None,
) -> list[DATestAPIKey]:
api_key_response = requests.get(
f"{API_SERVER_URL}/admin/api-key",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
api_key_response.raise_for_status()
return [TestAPIKey(**api_key) for api_key in api_key_response.json()]
return [DATestAPIKey(**api_key) for api_key in api_key_response.json()]

@staticmethod
def verify(
api_key: TestAPIKey,
api_key: DATestAPIKey,
verify_deleted: bool = False,
user_performing_action: TestUser | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
retrieved_keys = APIKeyManager.get_all(
user_performing_action=user_performing_action
Expand Down
34 changes: 17 additions & 17 deletions backend/tests/integration/common_utils/managers/cc_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.test_models import TestCCPair
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser


def _cc_pair_creator(
Expand All @@ -25,8 +25,8 @@ def _cc_pair_creator(
name: str | None = None,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
user_performing_action: TestUser | None = None,
) -> TestCCPair:
user_performing_action: DATestUser | None = None,
) -> DATestCCPair:
name = f"{name}-cc-pair" if name else f"test-cc-pair-{uuid4()}"

request = {
Expand All @@ -43,7 +43,7 @@ def _cc_pair_creator(
else GENERAL_HEADERS,
)
response.raise_for_status()
return TestCCPair(
return DATestCCPair(
id=response.json()["data"],
name=name,
connector_id=connector_id,
Expand All @@ -63,8 +63,8 @@ def create_from_scratch(
input_type: InputType = InputType.LOAD_STATE,
connector_specific_config: dict[str, Any] | None = None,
credential_json: dict[str, Any] | None = None,
user_performing_action: TestUser | None = None,
) -> TestCCPair:
user_performing_action: DATestUser | None = None,
) -> DATestCCPair:
connector = ConnectorManager.create(
name=name,
source=source,
Expand Down Expand Up @@ -98,8 +98,8 @@ def create(
name: str | None = None,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
user_performing_action: TestUser | None = None,
) -> TestCCPair:
user_performing_action: DATestUser | None = None,
) -> DATestCCPair:
return _cc_pair_creator(
connector_id=connector_id,
credential_id=credential_id,
Expand All @@ -111,8 +111,8 @@ def create(

@staticmethod
def pause_cc_pair(
cc_pair: TestCCPair,
user_performing_action: TestUser | None = None,
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
) -> None:
result = requests.put(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status",
Expand All @@ -125,8 +125,8 @@ def pause_cc_pair(

@staticmethod
def delete(
cc_pair: TestCCPair,
user_performing_action: TestUser | None = None,
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
) -> None:
cc_pair_identifier = ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
Expand All @@ -143,7 +143,7 @@ def delete(

@staticmethod
def get_all(
user_performing_action: TestUser | None = None,
user_performing_action: DATestUser | None = None,
) -> list[ConnectorIndexingStatus]:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/connector/indexing-status",
Expand All @@ -156,9 +156,9 @@ def get_all(

@staticmethod
def verify(
cc_pair: TestCCPair,
cc_pair: DATestCCPair,
verify_deleted: bool = False,
user_performing_action: TestUser | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
all_cc_pairs = CCPairManager.get_all(user_performing_action)
for retrieved_cc_pair in all_cc_pairs:
Expand All @@ -183,7 +183,7 @@ def verify(

@staticmethod
def wait_for_deletion_completion(
user_performing_action: TestUser | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
start = time.time()
while True:
Expand Down
24 changes: 12 additions & 12 deletions backend/tests/integration/common_utils/managers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestChatMessage
from tests.integration.common_utils.test_models import DATestChatSession
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.test_models import StreamedResponse
from tests.integration.common_utils.test_models import TestChatMessage
from tests.integration.common_utils.test_models import TestChatSession
from tests.integration.common_utils.test_models import TestUser


class ChatSessionManager:
@staticmethod
def create(
persona_id: int = -1,
description: str = "Test chat session",
user_performing_action: TestUser | None = None,
) -> TestChatSession:
user_performing_action: DATestUser | None = None,
) -> DATestChatSession:
chat_session_creation_req = ChatSessionCreationRequest(
persona_id=persona_id, description=description
)
Expand All @@ -38,7 +38,7 @@ def create(
)
response.raise_for_status()
chat_session_id = response.json()["chat_session_id"]
return TestChatSession(
return DATestChatSession(
id=chat_session_id, persona_id=persona_id, description=description
)

Expand All @@ -47,7 +47,7 @@ def send_message(
chat_session_id: int,
message: str,
parent_message_id: int | None = None,
user_performing_action: TestUser | None = None,
user_performing_action: DATestUser | None = None,
file_descriptors: list[FileDescriptor] = [],
prompt_id: int | None = None,
search_doc_ids: list[int] | None = None,
Expand Down Expand Up @@ -90,7 +90,7 @@ def send_message(
def get_answer_with_quote(
persona_id: int,
message: str,
user_performing_action: TestUser | None = None,
user_performing_action: DATestUser | None = None,
) -> StreamedResponse:
direct_qa_request = DirectQARequest(
messages=[ThreadMessage(message=message)],
Expand Down Expand Up @@ -137,9 +137,9 @@ def analyze_response(response: Response) -> StreamedResponse:

@staticmethod
def get_chat_history(
chat_session: TestChatSession,
user_performing_action: TestUser | None = None,
) -> list[TestChatMessage]:
chat_session: DATestChatSession,
user_performing_action: DATestUser | None = None,
) -> list[DATestChatMessage]:
response = requests.get(
f"{API_SERVER_URL}/chat/history/{chat_session.id}",
headers=user_performing_action.headers
Expand All @@ -149,7 +149,7 @@ def get_chat_history(
response.raise_for_status()

return [
TestChatMessage(
DATestChatMessage(
id=msg["id"],
chat_session_id=chat_session.id,
parent_message_id=msg.get("parent_message_id"),
Expand Down
Loading

0 comments on commit c8d1392

Please sign in to comment.