-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Start work on tests for LiteralDataLayer.
- Loading branch information
Showing
4 changed files
with
259 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,244 @@ | ||
import datetime | ||
import uuid | ||
from unittest.mock import AsyncMock, Mock | ||
|
||
import pytest | ||
from chainlit.data.literalai import LiteralDataLayer | ||
from chainlit.types import ThreadDict | ||
from chainlit.user import PersistedUser, User | ||
from literalai import AsyncLiteralClient | ||
from literalai import Thread as LiteralThread | ||
from literalai import User as LiteralUser | ||
from literalai import Step as LiteralStep | ||
from literalai.api import AsyncLiteralAPI | ||
|
||
|
||
@pytest.fixture | ||
async def mock_literal_client(monkeypatch: pytest.MonkeyPatch): | ||
client = Mock(spec=AsyncLiteralClient) | ||
client.api = AsyncMock(spec=AsyncLiteralAPI) | ||
monkeypatch.setattr("literalai.AsyncLiteralClient", client) | ||
yield client | ||
|
||
|
||
@pytest.fixture | ||
async def literal_data_layer(mock_literal_client): | ||
data_layer = LiteralDataLayer(api_key="fake_api_key", server="https://fake.server") | ||
data_layer.client = mock_literal_client | ||
return data_layer | ||
|
||
|
||
@pytest.fixture | ||
def test_thread(): | ||
return LiteralThread.from_dict( | ||
{ | ||
"id": "test_thread_id", | ||
"name": "Test Thread", | ||
"createdAt": "2023-01-01T00:00:00Z", | ||
"metadata": {}, | ||
"participant": {}, | ||
"steps": [], | ||
"tags": [], | ||
} | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def test_step(test_thread: LiteralThread): | ||
return LiteralStep.from_dict( | ||
{ | ||
"id": str(uuid.uuid4()), | ||
"name": "Test Step", | ||
"type": "user_message", | ||
"environment": None, | ||
"threadId": test_thread.id, | ||
"error": None, | ||
"input": {}, | ||
"output": {}, | ||
"metadata": {}, | ||
"tags": [], | ||
"parentId": None, | ||
"createdAt": "2023-01-01T00:00:00Z", | ||
"startTime": "2023-01-01T00:00:00Z", | ||
"endTime": "2023-01-01T00:00:00Z", | ||
"generation": {}, | ||
"scores": [], | ||
"attachments": [], | ||
"rootRunId": None, | ||
} | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def literal_test_user(test_user: User): | ||
return LiteralUser( | ||
id=str(uuid.uuid4()), | ||
created_at=datetime.datetime.now().isoformat(), | ||
identifier=test_user.identifier, | ||
metadata=test_user.metadata, | ||
) | ||
|
||
|
||
async def test_get_user( | ||
literal_data_layer: LiteralDataLayer, | ||
mock_literal_client: Mock, | ||
literal_test_user: LiteralUser, | ||
persisted_test_user: PersistedUser, | ||
): | ||
mock_literal_client.api.get_user.return_value = literal_test_user | ||
|
||
user = await literal_data_layer.get_user("test_user_id") | ||
|
||
assert user is not None | ||
assert user.id == literal_test_user.id | ||
assert user.identifier == literal_test_user.identifier | ||
|
||
mock_literal_client.api.get_user.assert_awaited_once_with(identifier="test_user_id") | ||
|
||
|
||
async def test_get_user_not_found( | ||
literal_data_layer: LiteralDataLayer, mock_literal_client: Mock | ||
): | ||
mock_literal_client.api.get_user.return_value = None | ||
|
||
user = await literal_data_layer.get_user("non_existent_user_id") | ||
|
||
assert user is None | ||
mock_literal_client.api.get_user.assert_awaited_once_with( | ||
identifier="non_existent_user_id" | ||
) | ||
|
||
|
||
async def test_create_user_not_existing( | ||
literal_data_layer: LiteralDataLayer, | ||
mock_literal_client: Mock, | ||
test_user: User, | ||
literal_test_user: LiteralUser, | ||
): | ||
mock_literal_client.api.get_user.return_value = None | ||
mock_literal_client.api.create_user.return_value = literal_test_user | ||
|
||
persisted_user = await literal_data_layer.create_user(test_user) | ||
|
||
mock_literal_client.api.create_user.assert_awaited_once_with( | ||
identifier=test_user.identifier, metadata=test_user.metadata | ||
) | ||
|
||
assert persisted_user is not None | ||
assert isinstance(persisted_user, PersistedUser) | ||
assert persisted_user.id == literal_test_user.id | ||
assert persisted_user.identifier == literal_test_user.identifier | ||
|
||
|
||
async def test_create_user_update_existing( | ||
literal_data_layer: LiteralDataLayer, | ||
mock_literal_client: Mock, | ||
test_user: User, | ||
literal_test_user: LiteralUser, | ||
persisted_test_user: PersistedUser, | ||
): | ||
mock_literal_client.api.get_user.return_value = literal_test_user | ||
|
||
persisted_user = await literal_data_layer.create_user(test_user) | ||
|
||
mock_literal_client.api.create_user.assert_not_called() | ||
mock_literal_client.api.update_user.assert_awaited_once_with( | ||
id=literal_test_user.id, metadata=test_user.metadata | ||
) | ||
|
||
assert persisted_user is not None | ||
assert isinstance(persisted_user, PersistedUser) | ||
assert persisted_user.id == literal_test_user.id | ||
assert persisted_user.identifier == literal_test_user.identifier | ||
|
||
|
||
async def test_create_user_id_none( | ||
literal_data_layer: LiteralDataLayer, | ||
mock_literal_client: Mock, | ||
test_user: User, | ||
literal_test_user: LiteralUser, | ||
): | ||
"""Weird edge case; persisted user without an id. Do we need this!??""" | ||
|
||
literal_test_user.id = None | ||
mock_literal_client.api.get_user.return_value = literal_test_user | ||
|
||
persisted_user = await literal_data_layer.create_user(test_user) | ||
|
||
mock_literal_client.api.create_user.assert_not_called() | ||
mock_literal_client.api.update_user.assert_not_called() | ||
|
||
assert persisted_user is not None | ||
assert isinstance(persisted_user, PersistedUser) | ||
assert persisted_user.id == "" | ||
assert persisted_user.identifier == literal_test_user.identifier | ||
|
||
|
||
async def test_update_thread( | ||
literal_data_layer: LiteralDataLayer, | ||
mock_literal_client: Mock, | ||
test_thread: LiteralThread, | ||
): | ||
await literal_data_layer.update_thread(test_thread.id, name=test_thread.name) | ||
|
||
mock_literal_client.api.upsert_thread.assert_awaited_once_with( | ||
id=test_thread.id, | ||
name=test_thread.name, | ||
participant_id=None, | ||
metadata=None, | ||
tags=None, | ||
) | ||
|
||
|
||
async def test_get_thread_author( | ||
literal_data_layer, mock_literal_client: Mock, test_thread: LiteralThread | ||
): | ||
test_thread.participant_identifier = "test_user_identifier" | ||
mock_literal_client.api.get_thread.return_value = test_thread | ||
|
||
author = await literal_data_layer.get_thread_author(test_thread.id) | ||
|
||
assert author == "test_user_identifier" | ||
mock_literal_client.api.get_thread.assert_awaited_once_with(id=test_thread.id) | ||
|
||
|
||
async def test_get_thread( | ||
literal_data_layer: LiteralDataLayer, | ||
mock_literal_client: Mock, | ||
test_thread: LiteralThread, | ||
test_step: LiteralStep, | ||
): | ||
assert isinstance(test_thread.steps, list) | ||
test_thread.steps.append(test_step) | ||
mock_literal_client.api.get_thread.return_value = test_thread | ||
|
||
thread = await literal_data_layer.get_thread(test_thread.id) | ||
mock_literal_client.api.get_thread.assert_awaited_once_with(id=test_thread.id) | ||
|
||
assert thread is not None | ||
assert thread["id"] == test_thread.id | ||
assert thread["name"] == test_thread.name | ||
assert thread["steps"] | ||
|
||
|
||
async def test_get_thread_non_existing( | ||
literal_data_layer: LiteralDataLayer, mock_literal_client: Mock | ||
): | ||
mock_literal_client.api.get_thread.return_value = None | ||
|
||
thread = await literal_data_layer.get_thread("non_existent_thread_id") | ||
mock_literal_client.api.get_thread.assert_awaited_once_with( | ||
id="non_existent_thread_id" | ||
) | ||
|
||
assert thread is None | ||
|
||
|
||
async def test_delete_thread( | ||
literal_data_layer: LiteralDataLayer, | ||
mock_literal_client: Mock, | ||
test_thread: LiteralThread, | ||
): | ||
await literal_data_layer.delete_thread(test_thread.id) | ||
|
||
mock_literal_client.api.delete_thread.assert_awaited_once_with(id=test_thread.id) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters