Skip to content

Commit

Permalink
[ENH]: add API to list all databases for tenant
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb committed Jan 9, 2025
1 parent bf522a5 commit b9bdfe2
Show file tree
Hide file tree
Showing 22 changed files with 893 additions and 242 deletions.
15 changes: 15 additions & 0 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,21 @@ def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
"""
pass

@abstractmethod
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
"""List all databases for a tenant. Raises an error if the tenant does not exist.
Args:
tenant: The tenant to list databases for.
"""
pass

@abstractmethod
def create_tenant(self, name: str) -> None:
"""Create a new tenant. Raises an error if the tenant already exists.
Expand Down
15 changes: 15 additions & 0 deletions chromadb/api/async_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,21 @@ async def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Databas
"""
pass

@abstractmethod
async def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
"""List all databases for a tenant. Raises an error if the tenant does not exist.
Args:
tenant: The tenant to list databases for.
"""
pass

@abstractmethod
async def create_tenant(self, name: str) -> None:
"""Create a new tenant. Raises an error if the tenant already exists.
Expand Down
11 changes: 11 additions & 0 deletions chromadb/api/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,17 @@ async def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None
async def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
return await self._server.get_database(name=name, tenant=tenant)

@override
async def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
return await self._server.list_databases(
limit=limit, offset=offset, tenant=tenant
)

@override
async def create_tenant(self, name: str) -> None:
return await self._server.create_tenant(name=name)
Expand Down
24 changes: 24 additions & 0 deletions chromadb/api/async_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,30 @@ async def get_database(
id=response["id"], name=response["name"], tenant=response["tenant"]
)

@trace_method("AsyncFastAPI.list_databases", OpenTelemetryGranularity.OPERATION)
@override
async def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
response = await self._make_request(
"get",
f"/tenants/{tenant}/databases",
params=BaseHTTPClient._clean_params(
{
"limit": limit,
"offset": offset,
}
),
)

return [
Database(id=db["id"], name=db["name"], tenant=db["tenant"])
for db in response
]

@trace_method("AsyncFastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
@override
async def create_tenant(self, name: str) -> None:
Expand Down
9 changes: 9 additions & 0 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,15 @@ def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
return self._server.get_database(name=name, tenant=tenant)

@override
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
return self._server.list_databases(limit, offset, tenant=tenant)

@override
def create_tenant(self, name: str) -> None:
return self._server.create_tenant(name=name)
Expand Down
25 changes: 25 additions & 0 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,31 @@ def get_database(
id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"]
)

@trace_method("FastAPI.list_databases", OpenTelemetryGranularity.OPERATION)
@override
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
"""Returns a list of all databases"""
json_databases = self._make_request(
"get",
f"/tenants/{tenant}/databases",
params=BaseHTTPClient._clean_params(
{
"limit": limit,
"offset": offset,
}
),
)
databases = [
Database(id=db["id"], name=db["name"], tenant=db["tenant"])
for db in json_databases
]
return databases

@trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
@override
def create_tenant(self, name: str) -> None:
Expand Down
22 changes: 18 additions & 4 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,16 @@ def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> t.Database:
return self._sysdb.get_database(name=name, tenant=tenant)

@trace_method("SegmentAPI.list_databases", OpenTelemetryGranularity.OPERATION)
@override
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[t.Database]:
return self._sysdb.list_databases(limit=limit, offset=offset, tenant=tenant)

@trace_method("SegmentAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
@override
def create_tenant(self, name: str) -> None:
Expand Down Expand Up @@ -227,7 +237,7 @@ def create_collection(
id=model.id,
name=model.name,
configuration=model.get_configuration(),
segments=[], # Passing empty till backend changes are deployed.
segments=[], # Passing empty till backend changes are deployed.
metadata=model.metadata,
dimension=None, # This is lazily populated on the first add
get_or_create=get_or_create,
Expand Down Expand Up @@ -894,11 +904,15 @@ def _get_collection(self, collection_id: UUID) -> t.Collection:

@trace_method("SegmentAPI._scan", OpenTelemetryGranularity.ALL)
def _scan(self, collection_id: UUID) -> Scan:
collection_and_segments = self._sysdb.get_collection_with_segments(collection_id)
collection_and_segments = self._sysdb.get_collection_with_segments(
collection_id
)
# For now collection should have exactly one segment per scope:
# - Local scopes: vector, metadata
# - Distributed scopes: vector, metadata, record
scope_to_segment = {segment["scope"]: segment for segment in collection_and_segments["segments"]}
# - Distributed scopes: vector, metadata, record
scope_to_segment = {
segment["scope"]: segment for segment in collection_and_segments["segments"]
}
return Scan(
collection=collection_and_segments["collection"],
knn=scope_to_segment[t.SegmentScope.VECTOR],
Expand Down
1 change: 1 addition & 0 deletions chromadb/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class AuthzAction(str, Enum):
GET_TENANT = "tenant:get_tenant"
CREATE_DATABASE = "db:create_database"
GET_DATABASE = "db:get_database"
LIST_DATABASES = "db:list_databases"
LIST_COLLECTIONS = "db:list_collections"
COUNT_COLLECTIONS = "db:count_collections"
CREATE_COLLECTION = "db:create_collection"
Expand Down
23 changes: 19 additions & 4 deletions chromadb/db/impl/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,15 @@ def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
raise NotFoundError()
raise InternalError()

@overrides
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
raise NotImplementedError()

@overrides
def create_tenant(self, name: str) -> None:
try:
Expand Down Expand Up @@ -310,7 +319,9 @@ def delete_collection(
f"Failed to delete collection id {id} for database {database} and tenant {tenant} due to error: {e}"
)
e = cast(grpc.Call, e)
logger.error(f"Error code: {e.code()}, NotFoundError: {grpc.StatusCode.NOT_FOUND}")
logger.error(
f"Error code: {e.code()}, NotFoundError: {grpc.StatusCode.NOT_FOUND}"
)
if e.code() == grpc.StatusCode.NOT_FOUND:
raise NotFoundError()
raise InternalError()
Expand Down Expand Up @@ -367,13 +378,17 @@ def get_collections(
raise InternalError()

@overrides
def get_collection_with_segments(self, collection_id: UUID) -> CollectionAndSegments:
def get_collection_with_segments(
self, collection_id: UUID
) -> CollectionAndSegments:
try:
request = GetCollectionWithSegmentsRequest(id=collection_id.hex)
response: GetCollectionWithSegmentsResponse = self._sys_db_stub.GetCollectionWithSegments(request)
response: GetCollectionWithSegmentsResponse = (
self._sys_db_stub.GetCollectionWithSegments(request)
)
return CollectionAndSegments(
collection=from_proto_collection(response.collection),
segments=[from_proto_segment(segment) for segment in response.segments]
segments=[from_proto_segment(segment) for segment in response.segments],
)
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.NOT_FOUND:
Expand Down
44 changes: 40 additions & 4 deletions chromadb/db/mixins/sysdb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
import sys
from typing import Optional, Sequence, Any, Tuple, cast, Dict, Union, Set
from uuid import UUID
from overrides import override
Expand All @@ -15,7 +16,11 @@
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System
from chromadb.db.base import Cursor, SqlDB, ParameterValue, get_sql
from chromadb.db.system import SysDB
from chromadb.errors import InvalidCollectionException, NotFoundError, UniqueConstraintError
from chromadb.errors import (
InvalidCollectionException,
NotFoundError,
UniqueConstraintError,
)
from chromadb.telemetry.opentelemetry import (
add_attributes_to_current_span,
OpenTelemetryClient,
Expand Down Expand Up @@ -111,6 +116,37 @@ def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
tenant=tenant,
)

@override
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
with self.tx() as cur:
databases = Table("databases")
q = (
self.querybuilder()
.from_(databases)
.select(databases.id, databases.name)
.where(databases.tenant_id == ParameterValue(tenant))
.offset(offset)
.limit(
sys.maxsize if limit is None else limit
) # SQLite requires that a limit is provided to use offset
.orderby(databases.created_at)
)
sql, params = get_sql(q, self.parameter_format())
rows = cur.execute(sql, params).fetchall()
return [
Database(
id=cast(UUID, self.uuid_from_db(row[0])),
name=row[1],
tenant=tenant,
)
for row in rows
]

@override
def create_tenant(self, name: str) -> None:
with self.tx() as cur:
Expand Down Expand Up @@ -195,15 +231,13 @@ def create_segment_with_tx(self, cur: Cursor, segment: Segment) -> None:
logger.error(f"Error inserting segment metadata: {e}")
raise


# TODO(rohit): Investigate and remove this method completely.
@trace_method("SqlSysDB.create_segment", OpenTelemetryGranularity.ALL)
@override
def create_segment(self, segment: Segment) -> None:
with self.tx() as cur:
self.create_segment_with_tx(cur, segment)


@trace_method("SqlSysDB.create_collection", OpenTelemetryGranularity.ALL)
@override
def create_collection(
Expand Down Expand Up @@ -491,7 +525,9 @@ def get_collections(
return collections

@override
def get_collection_with_segments(self, collection_id: UUID) -> CollectionAndSegments:
def get_collection_with_segments(
self, collection_id: UUID
) -> CollectionAndSegments:
collections = self.get_collections(id=collection_id)
if len(collections) == 0:
raise InvalidCollectionException(
Expand Down
15 changes: 12 additions & 3 deletions chromadb/db/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
exist."""
pass

@abstractmethod
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
"""List all databases for a tenant."""
pass

@abstractmethod
def create_tenant(self, name: str) -> None:
"""Create a new tenant in the System database. The name must be unique.
Expand Down Expand Up @@ -131,11 +141,10 @@ def get_collections(

@abstractmethod
def get_collection_with_segments(
self,
collection_id: UUID
self, collection_id: UUID
) -> CollectionAndSegments:
"""Get a consistent snapshot of a collection by id. This will return a collection with segment
information that matches the collection version and log position.
information that matches the collection version and log position.
"""
pass

Expand Down
Loading

0 comments on commit b9bdfe2

Please sign in to comment.