Skip to content

Commit

Permalink
feat: Start work on tests for LiteralDataLayer.
Browse files Browse the repository at this point in the history
  • Loading branch information
dokterbob committed Sep 27, 2024
1 parent bcac555 commit c9bbd7d
Show file tree
Hide file tree
Showing 4 changed files with 259 additions and 11 deletions.
15 changes: 9 additions & 6 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, Mock

Expand All @@ -10,10 +11,12 @@


@pytest.fixture
def mock_persisted_user():
mock = Mock(spec=PersistedUser)
mock.id = "test_user_id"
return mock
def persisted_test_user():
return PersistedUser(
id="test_user_id",
createdAt=datetime.datetime.now().isoformat(),
identifier="test_user_identifier",
)


@pytest.fixture
Expand Down Expand Up @@ -44,8 +47,8 @@ async def create_chainlit_context(mock_session):


@pytest_asyncio.fixture
async def mock_chainlit_context(mock_persisted_user, mock_session):
mock_session.user = mock_persisted_user
async def mock_chainlit_context(persisted_test_user, mock_session):
mock_session.user = persisted_test_user
return create_chainlit_context(mock_session)


Expand Down
6 changes: 6 additions & 0 deletions backend/tests/data/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest.mock import AsyncMock

from chainlit.data.base import BaseStorageClient
from chainlit.user import User


@pytest.fixture
Expand All @@ -13,3 +14,8 @@ def mock_storage_client():
"object_key": "test_user/test_element/test.txt",
}
return mock_client


@pytest.fixture
def test_user() -> User:
return User(identifier="test_user_identifier", metadata={})
244 changes: 244 additions & 0 deletions backend/tests/data/test_literalai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
import datetime
import uuid
from unittest.mock import AsyncMock, Mock

import pytest
from chainlit.data.literalai import LiteralDataLayer
from chainlit.types import ThreadDict
from chainlit.user import PersistedUser, User
from literalai import AsyncLiteralClient
from literalai import Thread as LiteralThread
from literalai import User as LiteralUser
from literalai import Step as LiteralStep
from literalai.api import AsyncLiteralAPI


@pytest.fixture
async def mock_literal_client(monkeypatch: pytest.MonkeyPatch):
client = Mock(spec=AsyncLiteralClient)
client.api = AsyncMock(spec=AsyncLiteralAPI)
monkeypatch.setattr("literalai.AsyncLiteralClient", client)
yield client


@pytest.fixture
async def literal_data_layer(mock_literal_client):
data_layer = LiteralDataLayer(api_key="fake_api_key", server="https://fake.server")
data_layer.client = mock_literal_client
return data_layer


@pytest.fixture
def test_thread():
return LiteralThread.from_dict(
{
"id": "test_thread_id",
"name": "Test Thread",
"createdAt": "2023-01-01T00:00:00Z",
"metadata": {},
"participant": {},
"steps": [],
"tags": [],
}
)


@pytest.fixture
def test_step(test_thread: LiteralThread):
return LiteralStep.from_dict(
{
"id": str(uuid.uuid4()),
"name": "Test Step",
"type": "user_message",
"environment": None,
"threadId": test_thread.id,
"error": None,
"input": {},
"output": {},
"metadata": {},
"tags": [],
"parentId": None,
"createdAt": "2023-01-01T00:00:00Z",
"startTime": "2023-01-01T00:00:00Z",
"endTime": "2023-01-01T00:00:00Z",
"generation": {},
"scores": [],
"attachments": [],
"rootRunId": None,
}
)


@pytest.fixture
def literal_test_user(test_user: User):
return LiteralUser(
id=str(uuid.uuid4()),
created_at=datetime.datetime.now().isoformat(),
identifier=test_user.identifier,
metadata=test_user.metadata,
)


async def test_get_user(
literal_data_layer: LiteralDataLayer,
mock_literal_client: Mock,
literal_test_user: LiteralUser,
persisted_test_user: PersistedUser,
):
mock_literal_client.api.get_user.return_value = literal_test_user

user = await literal_data_layer.get_user("test_user_id")

assert user is not None
assert user.id == literal_test_user.id
assert user.identifier == literal_test_user.identifier

mock_literal_client.api.get_user.assert_awaited_once_with(identifier="test_user_id")


async def test_get_user_not_found(
literal_data_layer: LiteralDataLayer, mock_literal_client: Mock
):
mock_literal_client.api.get_user.return_value = None

user = await literal_data_layer.get_user("non_existent_user_id")

assert user is None
mock_literal_client.api.get_user.assert_awaited_once_with(
identifier="non_existent_user_id"
)


async def test_create_user_not_existing(
literal_data_layer: LiteralDataLayer,
mock_literal_client: Mock,
test_user: User,
literal_test_user: LiteralUser,
):
mock_literal_client.api.get_user.return_value = None
mock_literal_client.api.create_user.return_value = literal_test_user

persisted_user = await literal_data_layer.create_user(test_user)

mock_literal_client.api.create_user.assert_awaited_once_with(
identifier=test_user.identifier, metadata=test_user.metadata
)

assert persisted_user is not None
assert isinstance(persisted_user, PersistedUser)
assert persisted_user.id == literal_test_user.id
assert persisted_user.identifier == literal_test_user.identifier


async def test_create_user_update_existing(
literal_data_layer: LiteralDataLayer,
mock_literal_client: Mock,
test_user: User,
literal_test_user: LiteralUser,
persisted_test_user: PersistedUser,
):
mock_literal_client.api.get_user.return_value = literal_test_user

persisted_user = await literal_data_layer.create_user(test_user)

mock_literal_client.api.create_user.assert_not_called()
mock_literal_client.api.update_user.assert_awaited_once_with(
id=literal_test_user.id, metadata=test_user.metadata
)

assert persisted_user is not None
assert isinstance(persisted_user, PersistedUser)
assert persisted_user.id == literal_test_user.id
assert persisted_user.identifier == literal_test_user.identifier


async def test_create_user_id_none(
literal_data_layer: LiteralDataLayer,
mock_literal_client: Mock,
test_user: User,
literal_test_user: LiteralUser,
):
"""Weird edge case; persisted user without an id. Do we need this!??"""

literal_test_user.id = None
mock_literal_client.api.get_user.return_value = literal_test_user

persisted_user = await literal_data_layer.create_user(test_user)

mock_literal_client.api.create_user.assert_not_called()
mock_literal_client.api.update_user.assert_not_called()

assert persisted_user is not None
assert isinstance(persisted_user, PersistedUser)
assert persisted_user.id == ""
assert persisted_user.identifier == literal_test_user.identifier


async def test_update_thread(
literal_data_layer: LiteralDataLayer,
mock_literal_client: Mock,
test_thread: LiteralThread,
):
await literal_data_layer.update_thread(test_thread.id, name=test_thread.name)

mock_literal_client.api.upsert_thread.assert_awaited_once_with(
id=test_thread.id,
name=test_thread.name,
participant_id=None,
metadata=None,
tags=None,
)


async def test_get_thread_author(
literal_data_layer, mock_literal_client: Mock, test_thread: LiteralThread
):
test_thread.participant_identifier = "test_user_identifier"
mock_literal_client.api.get_thread.return_value = test_thread

author = await literal_data_layer.get_thread_author(test_thread.id)

assert author == "test_user_identifier"
mock_literal_client.api.get_thread.assert_awaited_once_with(id=test_thread.id)


async def test_get_thread(
literal_data_layer: LiteralDataLayer,
mock_literal_client: Mock,
test_thread: LiteralThread,
test_step: LiteralStep,
):
assert isinstance(test_thread.steps, list)
test_thread.steps.append(test_step)
mock_literal_client.api.get_thread.return_value = test_thread

thread = await literal_data_layer.get_thread(test_thread.id)
mock_literal_client.api.get_thread.assert_awaited_once_with(id=test_thread.id)

assert thread is not None
assert thread["id"] == test_thread.id
assert thread["name"] == test_thread.name
assert thread["steps"]


async def test_get_thread_non_existing(
literal_data_layer: LiteralDataLayer, mock_literal_client: Mock
):
mock_literal_client.api.get_thread.return_value = None

thread = await literal_data_layer.get_thread("non_existent_thread_id")
mock_literal_client.api.get_thread.assert_awaited_once_with(
id="non_existent_thread_id"
)

assert thread is None


async def test_delete_thread(
literal_data_layer: LiteralDataLayer,
mock_literal_client: Mock,
test_thread: LiteralThread,
):
await literal_data_layer.delete_thread(test_thread.id)

mock_literal_client.api.delete_thread.assert_awaited_once_with(id=test_thread.id)
5 changes: 0 additions & 5 deletions backend/tests/data/test_sql_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,6 @@ async def data_layer(mock_storage_client: BaseStorageClient, tmp_path: Path):
yield data_layer


@pytest.fixture
def test_user() -> User:
return User(identifier="sqlalchemy_test_user_id")


async def test_create_and_get_element(
mock_chainlit_context, data_layer: SQLAlchemyDataLayer
):
Expand Down

0 comments on commit c9bbd7d

Please sign in to comment.