Skip to content

Commit

Permalink
[ENH]: add distributed impl for ListDatabases
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb committed Jan 10, 2025
1 parent aa27c47 commit 4b233ff
Show file tree
Hide file tree
Showing 9 changed files with 109 additions and 10 deletions.
23 changes: 21 additions & 2 deletions chromadb/db/impl/grpc/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
from typing import List, Optional, Sequence, Tuple, Union, cast
from uuid import UUID
from overrides import overrides
Expand Down Expand Up @@ -27,6 +26,7 @@
GetDatabaseRequest,
GetSegmentsRequest,
GetTenantRequest,
ListDatabasesRequest,
UpdateCollectionRequest,
UpdateSegmentRequest,
)
Expand Down Expand Up @@ -133,7 +133,26 @@ def list_databases(
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
raise NotImplementedError()
try:
request = ListDatabasesRequest(limit=limit, offset=offset, tenant=tenant)
response = self._sys_db_stub.ListDatabases(
request, timeout=self._request_timeout_seconds
)
results: List[Database] = []
for proto_database in response.databases:
results.append(
Database(
id=UUID(hex=proto_database.id),
name=proto_database.name,
tenant=proto_database.tenant,
)
)
return results
except grpc.RpcError as e:
logger.info(
f"Failed to list databases for tenant {tenant} due to error: {e}"
)
raise InternalError()

@overrides
def create_tenant(self, name: str) -> None:
Expand Down
9 changes: 1 addition & 8 deletions chromadb/test/api/test_list_databases.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from typing import Dict, List
from hypothesis import given
import pytest
from chromadb.test.conftest import NOT_CLUSTER_ONLY, ClientFactories
from chromadb.test.conftest import ClientFactories
import hypothesis.strategies as st


def test_list_databases(client_factories: ClientFactories) -> None:
if not NOT_CLUSTER_ONLY:
pytest.skip("This API is not yet supported by distributed")

admin_client = client_factories.create_admin_client_from_system()

for i in range(10):
Expand Down Expand Up @@ -66,9 +62,6 @@ def test_list_databases_with_limit_offset(
tenants_and_databases: Dict[str, List[str]],
client_factories: ClientFactories,
) -> None:
if not NOT_CLUSTER_ONLY:
pytest.skip("This API is not yet supported by distributed")

client = client_factories.create_client()
client.reset()

Expand Down
8 changes: 8 additions & 0 deletions go/pkg/sysdb/coordinator/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ func (s *Coordinator) GetDatabase(ctx context.Context, getDatabase *model.GetDat
return database, nil
}

// ListDatabases returns the databases described by listDatabases by delegating
// to the catalog, forwarding listDatabases.Ts as the timestamp argument.
// A catalog error is returned unchanged together with a nil slice.
func (s *Coordinator) ListDatabases(ctx context.Context, listDatabases *model.ListDatabases) ([]*model.Database, error) {
	found, listErr := s.catalog.ListDatabases(ctx, listDatabases, listDatabases.Ts)
	if listErr == nil {
		return found, nil
	}
	return nil, listErr
}

func (s *Coordinator) CreateTenant(ctx context.Context, createTenant *model.CreateTenant) (*model.Tenant, error) {
tenant, err := s.catalog.CreateTenant(ctx, createTenant, createTenant.Ts)
if err != nil {
Expand Down
7 changes: 7 additions & 0 deletions go/pkg/sysdb/coordinator/model/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,10 @@ type GetDatabase struct {
Tenant string
Ts types.Timestamp
}

// ListDatabases carries the parameters of a request to list the databases
// that belong to a tenant.
type ListDatabases struct {
// Limit is the maximum number of databases to return; nil applies no limit.
Limit *int32
// Offset is the number of databases to skip; nil applies no offset.
Offset *int32
// Tenant is the tenant ID the listing is filtered by.
Tenant string
// Ts is forwarded by the coordinator to the catalog as the timestamp
// argument of its ListDatabases call.
Ts types.Timestamp
}
12 changes: 12 additions & 0 deletions go/pkg/sysdb/coordinator/table_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,18 @@ func (tc *Catalog) GetDatabases(ctx context.Context, getDatabase *model.GetDatab
return result[0], nil
}

// ListDatabases loads the database records for listDatabases.Tenant from the
// metastore, passing along the optional Limit and Offset, and converts every
// record to its model representation. The ts argument is accepted but not
// read by this method.
func (tc *Catalog) ListDatabases(ctx context.Context, listDatabases *model.ListDatabases, ts types.Timestamp) ([]*model.Database, error) {
	records, err := tc.metaDomain.DatabaseDb(ctx).ListDatabases(listDatabases.Limit, listDatabases.Offset, listDatabases.Tenant)
	if err != nil {
		return nil, err
	}

	// The output has exactly one entry per record, so fill it by index.
	models := make([]*model.Database, len(records))
	for i, record := range records {
		models[i] = convertDatabaseToModel(record)
	}
	return models, nil
}

func (tc *Catalog) GetAllDatabases(ctx context.Context, ts types.Timestamp) ([]*model.Database, error) {
databases, err := tc.metaDomain.DatabaseDb(ctx).GetAllDatabases()
if err != nil {
Expand Down
25 changes: 25 additions & 0 deletions go/pkg/sysdb/grpc/tenant_database_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,31 @@ func (s *Server) GetDatabase(ctx context.Context, req *coordinatorpb.GetDatabase
return res, nil
}

// ListDatabases handles the ListDatabases RPC: it forwards the request's
// tenant, limit and offset to the coordinator and copies the ID, name and
// tenant of every returned database into the response.
//
// A coordinator error is logged together with the request. It is reported as
// a not-found gRPC error when it equals common.ErrTenantNotFound and as an
// internal gRPC error otherwise.
func (s *Server) ListDatabases(ctx context.Context, req *coordinatorpb.ListDatabasesRequest) (*coordinatorpb.ListDatabasesResponse, error) {
	res := &coordinatorpb.ListDatabasesResponse{}

	found, err := s.coordinator.ListDatabases(ctx, &model.ListDatabases{
		Limit:  req.Limit,
		Offset: req.Offset,
		Tenant: req.GetTenant(),
	})
	if err != nil {
		log.Error("error ListDatabases", zap.String("request", req.String()), zap.Error(err))
		switch err {
		case common.ErrTenantNotFound:
			return res, grpcutils.BuildNotFoundGrpcError(err.Error())
		default:
			return res, grpcutils.BuildInternalGrpcError(err.Error())
		}
	}

	for _, db := range found {
		entry := &coordinatorpb.Database{
			Id:     db.ID,
			Name:   db.Name,
			Tenant: db.Tenant,
		}
		res.Databases = append(res.Databases, entry)
	}
	return res, nil
}

func (s *Server) CreateTenant(ctx context.Context, req *coordinatorpb.CreateTenantRequest) (*coordinatorpb.CreateTenantResponse, error) {
res := &coordinatorpb.CreateTenantResponse{}
createTenant := &model.CreateTenant{
Expand Down
22 changes: 22 additions & 0 deletions go/pkg/sysdb/metastore/db/dao/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,28 @@ func (s *databaseDb) GetAllDatabases() ([]*dbmodel.Database, error) {
return databases, nil
}

// ListDatabases returns the databases whose tenant_id equals tenantID, ordered
// by creation time (oldest first) with the database id as a tie-breaker.
//
// limit and offset are optional: a nil pointer leaves the corresponding clause
// off the query. Only the id, name and tenant_id columns are selected.
//
// A query error is logged and returned together with a nil slice.
func (s *databaseDb) ListDatabases(limit *int32, offset *int32, tenantID string) ([]*dbmodel.Database, error) {
	var databases []*dbmodel.Database

	// created_at alone is not guaranteed to be unique. Without a second sort
	// key, rows sharing a timestamp may come back in a different order on each
	// call, so consecutive limit/offset pages could skip or repeat a database.
	// Sorting by id as well makes the order total.
	// NOTE(review): assumes databases.id is unique (primary key) — confirm.
	query := s.db.Table("databases").
		Select("databases.id, databases.name, databases.tenant_id").
		Where("databases.tenant_id = ?", tenantID).
		Order("databases.created_at ASC, databases.id ASC")

	if limit != nil {
		query = query.Limit(int(*limit))
	}

	if offset != nil {
		query = query.Offset(int(*offset))
	}

	if err := query.Find(&databases).Error; err != nil {
		log.Error("ListDatabases", zap.Error(err))
		return nil, err
	}
	return databases, nil
}

func (s *databaseDb) GetDatabases(tenantID string, databaseName string) ([]*dbmodel.Database, error) {
var databases []*dbmodel.Database
query := s.db.Table("databases").
Expand Down
1 change: 1 addition & 0 deletions go/pkg/sysdb/metastore/db/dbmodel/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ func (v Database) TableName() string {
type IDatabaseDb interface {
GetAllDatabases() ([]*Database, error)
GetDatabases(tenantID string, databaseName string) ([]*Database, error)
ListDatabases(limit *int32, offset *int32, tenantID string) ([]*Database, error)
Insert(in *Database) error
DeleteAll() error
}
12 changes: 12 additions & 0 deletions idl/chromadb/proto/coordinator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ message GetDatabaseResponse {
reserved "status";
}

// ListDatabasesRequest asks for the databases that belong to a tenant.
message ListDatabasesRequest {
// Tenant whose databases are listed.
string tenant = 1;
// Maximum number of databases to return; when unset, no limit is applied.
optional int32 limit = 2;
// Number of databases to skip; when unset, no offset is applied.
optional int32 offset = 3;
}

// ListDatabasesResponse carries the databases returned by the ListDatabases RPC.
message ListDatabasesResponse {
// One entry per database returned for the request.
repeated Database databases = 1;
// Prevents the field name "status" from being used in this message.
reserved "status";
}

message CreateTenantRequest {
string name = 2; // Names are globally unique
}
Expand Down Expand Up @@ -346,6 +357,7 @@ message RestoreCollectionResponse {
service SysDB {
rpc CreateDatabase(CreateDatabaseRequest) returns (CreateDatabaseResponse) {}
rpc GetDatabase(GetDatabaseRequest) returns (GetDatabaseResponse) {}
rpc ListDatabases(ListDatabasesRequest) returns (ListDatabasesResponse) {}
rpc CreateTenant(CreateTenantRequest) returns (CreateTenantResponse) {}
rpc GetTenant(GetTenantRequest) returns (GetTenantResponse) {}
rpc CreateSegment(CreateSegmentRequest) returns (CreateSegmentResponse) {}
Expand Down

0 comments on commit 4b233ff

Please sign in to comment.