Skip to content

Commit

Permalink
Improve test
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb committed Jan 9, 2025
1 parent 1ce628b commit 4b381c2
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 27 deletions.
2 changes: 1 addition & 1 deletion chromadb/db/mixins/sysdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def list_databases(
.where(databases.tenant_id == ParameterValue(tenant))
.limit(limit)
.offset(offset)
.orderby(databases.id)
.orderby(databases.created_at)
)
sql, params = get_sql(q, self.parameter_format())
rows = cur.execute(sql, params).fetchall()
Expand Down
75 changes: 49 additions & 26 deletions chromadb/test/api/test_list_databases.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Dict, List
from hypothesis import given
import pytest
from chromadb.test.conftest import NOT_CLUSTER_ONLY, ClientFactories
Expand All @@ -20,30 +21,48 @@ def test_list_databases(client_factories: ClientFactories) -> None:
assert any(d["name"] == f"test_list_databases_{i}" for d in databases)


def test_list_databases_is_tenant_scoped(client_factories: ClientFactories) -> None:
if not NOT_CLUSTER_ONLY:
pytest.skip("This API is not yet supported by distributed")
@st.composite
def tenants_and_databases_st(
draw: st.DrawFn, max_tenants: int, max_databases: int
) -> Dict[str, List[str]]:
"""Generates a set of random tenants and databases. Each database is assigned to a random tenant. Returns a dictionary where the key is the tenant name and the value is a list of database names for that tenant."""
num_tenants = draw(st.integers(min_value=1, max_value=max_tenants))
num_databases = draw(st.integers(min_value=0, max_value=max_databases))

admin_client = client_factories.create_admin_client_from_system()
admin_client.create_tenant("new_tenant")
database_i_to_tenant_i = draw(
st.lists(
st.integers(min_value=0, max_value=num_tenants - 1),
min_size=num_databases,
max_size=num_databases,
)
)

for i in range(10):
admin_client.create_database(f"test_list_databases_is_tenant_scoped_{i}")
tenants = [f"tenant_{i}" for i in range(num_tenants)]
databases = [f"database_{i}" for i in range(num_databases)]

result: Dict[str, List[str]] = {}
for database_i, tenant_i in enumerate(database_i_to_tenant_i):
tenant = tenants[tenant_i]
database = databases[database_i]

if tenant not in result:
result[tenant] = []

databases = admin_client.list_databases(None, None, "new_tenant")
assert len(databases) == 0
result[tenant].append(database)

databases_for_default_tenant = admin_client.list_databases()
assert len(databases_for_default_tenant) >= 10
return result


@given(
limit=st.integers(min_value=1, max_value=5),
offset=st.integers(min_value=0, max_value=5),
num_databases=st.integers(min_value=0, max_value=10),
limit=st.integers(min_value=1, max_value=10),
offset=st.integers(min_value=0, max_value=10),
tenants_and_databases=tenants_and_databases_st(max_tenants=10, max_databases=10),
)
def test_list_databases_with_limit_offset(
limit: int, offset: int, num_databases: int, client_factories: ClientFactories
limit: int,
offset: int,
tenants_and_databases: Dict[str, List[str]],
client_factories: ClientFactories,
) -> None:
if not NOT_CLUSTER_ONLY:
pytest.skip("This API is not yet supported by distributed")
Expand All @@ -53,17 +72,21 @@ def test_list_databases_with_limit_offset(

admin_client = client_factories.create_admin_client_from_system()

for i in range(num_databases):
admin_client.create_database(f"test_list_databases_with_limit_offset_{i}")
for tenant, databases in tenants_and_databases.items():
admin_client.create_tenant(tenant)

sliced_databases = admin_client.list_databases(limit=limit, offset=offset)
for database in databases:
admin_client.create_database(database, tenant)

all_databases = admin_client.list_databases()
expected_databases = list(all_databases[offset : offset + limit])
total_databases = max(num_databases + 1 - offset, 0) # add 1 for default_database
for tenant, all_databases in tenants_and_databases.items():
listed_databases = admin_client.list_databases(
limit=limit, offset=offset, tenant=tenant
)
expected_databases = all_databases[offset : offset + limit]

if limit + offset > num_databases:
assert len(sliced_databases) == total_databases
assert sliced_databases == expected_databases
else:
assert len(sliced_databases) == limit
if limit + offset > len(all_databases):
assert len(listed_databases) == max(len(all_databases) - offset, 0)
assert [d["name"] for d in listed_databases] == expected_databases
else:
assert len(listed_databases) == limit
assert [d["name"] for d in listed_databases] == expected_databases

0 comments on commit 4b381c2

Please sign in to comment.