Skip to content

Commit

Permalink
[ENH]: Enable cluster property test for collection (#2004)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - Enable the collection property tests in the cluster test workflow, and fix the collection and sysdb code paths they exercise
  • Loading branch information
nicolasgere authored Apr 25, 2024
1 parent 99381f2 commit ef24ce2
Show file tree
Hide file tree
Showing 20 changed files with 2,220 additions and 368 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/chroma-cluster-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ jobs:
python: ['3.8']
platform: ['16core-64gb-ubuntu-latest']
testfile: ["chromadb/test/db/test_system.py",
"chromadb/test/property/test_collections.py",
"chromadb/test/property/test_collections_with_database_tenant.py",
"chromadb/test/ingest/test_producer_consumer.py",
"chromadb/test/segment/distributed/test_memberlist_provider.py",
"chromadb/test/test_logservice.py"]
Expand Down
2 changes: 1 addition & 1 deletion Tiltfile
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ k8s_resource(
)

# Production Chroma
k8s_resource('postgres', resource_deps=['k8s_setup', 'namespace'], labels=["infrastructure"])
k8s_resource('postgres', resource_deps=['k8s_setup', 'namespace'], labels=["infrastructure"], port_forwards='5432:5432')
k8s_resource('sysdb-migration', resource_deps=['postgres', 'namespace'], labels=["infrastructure"])
k8s_resource('logservice-migration', resource_deps=['postgres', 'namespace'], labels=["infrastructure"])
k8s_resource('logservice', resource_deps=['sysdb-migration'], labels=["chroma"], port_forwards='50052:50051')
Expand Down
4 changes: 4 additions & 0 deletions chromadb/db/impl/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ def get_collections(
name=name,
tenant=tenant,
database=database,
limit=limit,
offset=offset,
)
response: GetCollectionsResponse = self._sys_db_stub.GetCollections(request)
results: List[Collection] = []
Expand Down Expand Up @@ -293,6 +295,8 @@ def update_collection(
response = self._sys_db_stub.UpdateCollection(request)
if response.status.code == 404:
raise NotFoundError()
if response.status.code == 409:
raise UniqueConstraintError()

def reset_and_wait_for_ready(self) -> None:
self._sys_db_stub.ResetState(Empty(), wait_for_ready=True)
2 changes: 1 addition & 1 deletion chromadb/proto/chroma_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion chromadb/proto/coordinator_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion chromadb/proto/logservice_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion chromadb/segment/impl/manager/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def create_segments(self, collection: Collection) -> Sequence[Segment]:

@override
def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
raise NotImplementedError()
segments = self._sysdb.get_segments(collection=collection_id)
return [s["id"] for s in segments]

@trace_method(
"DistributedSegmentManager.get_segment",
Expand Down
6 changes: 6 additions & 0 deletions go/migrations/20240411201006.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- Drop index "uni_collections_name" from table: "collections"
DROP INDEX "public"."uni_collections_name";
-- Create index "idx_name" to table: "collections"
CREATE UNIQUE INDEX "idx_name" ON "public"."collections" ("name", "database_id");
-- Drop "record_logs" table
DROP TABLE "public"."record_logs";
3 changes: 2 additions & 1 deletion go/migrations/atlas.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
h1:9rYxc6RcMJ3Cd4SPZoQ+T6XAUZ7yyWCEnwoRkj1af3c=
h1:0fWoYRU+wWRTteHvXVV8mOc+AJYEJzDFuH4c5hLWKR0=
20240313233558.sql h1:Gv0TiSYsqGoOZ2T2IWvX4BOasauxool8PrBOIjmmIdg=
20240321194713.sql h1:kVkNpqSFhrXGVGFFvL7JdK3Bw31twFcEhI6A0oCFCkg=
20240327075032.sql h1:nlr2J74XRU8erzHnKJgMr/tKqJxw9+R6RiiEBuvuzgo=
20240327172649.sql h1:UUGo6AzWXKLcpYVd5qH6Hv9jpHNV86z42o6ft5OR0zU=
20240411201006.sql h1:jjzYJPzDVTxQAvOI7gRtNTiZJHy1Hpw5urP8EzqxgUk=
2 changes: 1 addition & 1 deletion go/pkg/coordinator/apis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ func (suite *APIsTestSuite) TestCreateDatabaseWithTenants() {
// results are returned
result, err = suite.coordinator.GetCollections(ctx, types.NilUniqueID(), nil, newTenantName, suite.databaseName, nil, nil)
suite.NoError(err)
suite.Nil(result)
suite.Equal(0, len(result))

// clean up
err = dao.CleanUpTestTenant(suite.db, newTenantName)
Expand Down
7 changes: 6 additions & 1 deletion go/pkg/coordinator/grpc/collection_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,14 @@ func (s *Server) UpdateCollection(ctx context.Context, req *coordinatorpb.Update
}

_, err = s.coordinator.UpdateCollection(ctx, updateCollection)

if err != nil {
log.Error("error updating collection", zap.Error(err))
res.Status = failResponseWithError(err, errorCode)
if err == common.ErrCollectionUniqueConstraintViolation {
res.Status = failResponseWithError(err, 409)
} else {
res.Status = failResponseWithError(err, errorCode)
}
return res, nil
}

Expand Down
6 changes: 4 additions & 2 deletions go/pkg/coordinator/grpc/tenant_database_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package grpc

import (
"context"
"errors"
"github.com/chroma-core/chroma/go/pkg/grpcutils"
"github.com/pingcap/log"
"go.uber.org/zap"
Expand All @@ -21,10 +22,11 @@ func (s *Server) CreateDatabase(ctx context.Context, req *coordinatorpb.CreateDa
}
_, err := s.coordinator.CreateDatabase(ctx, createDatabase)
if err != nil {
if err == common.ErrDatabaseUniqueConstraintViolation {
if errors.Is(err, common.ErrDatabaseUniqueConstraintViolation) {
res.Status = failResponseWithError(err, 409)
return res, nil
return res, err
}

res.Status = failResponseWithError(err, errorCode)
return res, nil
}
Expand Down
150 changes: 57 additions & 93 deletions go/pkg/metastore/db/dao/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package dao
import (
"database/sql"
"errors"
"strings"

"github.com/chroma-core/chroma/go/pkg/common"
"github.com/jackc/pgx/v5/pgconn"
"gorm.io/gorm/clause"
Expand All @@ -26,57 +24,38 @@ func (s *collectionDb) DeleteAll() error {
return s.db.Where("1 = 1").Delete(&dbmodel.Collection{}).Error
}

func (s *collectionDb) GetCollections(id *string, name *string, tenantID string, databaseName string, limit *int32, offset *int32) ([]*dbmodel.CollectionAndMetadata, error) {
var getCollectionInput strings.Builder
getCollectionInput.WriteString("GetCollections input: ")

var collections []*dbmodel.CollectionAndMetadata

func (s *collectionDb) GetCollections(id *string, name *string, tenantID string, databaseName string, limit *int32, offset *int32) (collectionWithMetdata []*dbmodel.CollectionAndMetadata, err error) {
var collections []*dbmodel.Collection
query := s.db.Table("collections").
Select("collections.id, collections.log_position, collections.version, collections.name, collections.dimension, collections.database_id, collections.created_at, databases.name, databases.tenant_id, collection_metadata.key, collection_metadata.str_value, collection_metadata.int_value, collection_metadata.float_value").
Joins("LEFT JOIN collection_metadata ON collections.id = collection_metadata.collection_id").
Select("collections.id, collections.log_position, collections.version, collections.name, collections.dimension, collections.database_id, databases.name, databases.tenant_id").
Joins("INNER JOIN databases ON collections.database_id = databases.id").
Order("collections.id asc")
if limit != nil {
query = query.Limit(int(*limit))
getCollectionInput.WriteString("limit: " + string(*limit) + ", ")
}

if offset != nil {
query = query.Offset(int(*offset))
getCollectionInput.WriteString("offset: " + string(*offset) + ", ")
}
Order("collections.created_at ASC")

if databaseName != "" {
query = query.Where("databases.name = ?", databaseName)
getCollectionInput.WriteString("databases.name: " + databaseName + ", ")
}

if tenantID != "" {
query = query.Where("databases.tenant_id = ?", tenantID)
getCollectionInput.WriteString("databases.tenant_id: " + tenantID + ", ")
}

if id != nil {
query = query.Where("collections.id = ?", *id)
getCollectionInput.WriteString("collections.id: " + *id + ", ")
}
if name != nil {
query = query.Where("collections.name = ?", *name)
getCollectionInput.WriteString("collections.name: " + *name + ", ")
}
log.Info(getCollectionInput.String())

if limit != nil {
query = query.Limit(int(*limit))
}
if offset != nil {
query = query.Offset(int(*offset))

}
rows, err := query.Rows()
if err != nil {
return nil, err
}
defer rows.Close()

var currentCollectionID string = ""
var metadata []*dbmodel.CollectionMetadata
var currentCollection *dbmodel.CollectionAndMetadata

collectionWithMetdata = make([]*dbmodel.CollectionAndMetadata, 0, len(collections))
for rows.Next() {
var (
collectionID string
Expand All @@ -88,78 +67,46 @@ func (s *collectionDb) GetCollections(id *string, name *string, tenantID string,
collectionCreatedAt sql.NullTime
databaseName string
databaseTenantID string
key sql.NullString
strValue sql.NullString
intValue sql.NullInt64
floatValue sql.NullFloat64
)

err := rows.Scan(&collectionID, &logPosition, &version, &collectionName, &collectionDimension, &collectionDatabaseID, &collectionCreatedAt, &databaseName, &databaseTenantID, &key, &strValue, &intValue, &floatValue)
err := rows.Scan(&collectionID, &logPosition, &version, &collectionName, &collectionDimension, &collectionDatabaseID, &databaseName, &databaseTenantID)
if err != nil {
log.Error("scan collection failed", zap.Error(err))
return nil, err
}
if collectionID != currentCollectionID {
currentCollectionID = collectionID
metadata = nil

currentCollection = &dbmodel.CollectionAndMetadata{
Collection: &dbmodel.Collection{
ID: collectionID,
Name: &collectionName,
DatabaseID: collectionDatabaseID,
LogPosition: logPosition,
Version: version,
},
CollectionMetadata: metadata,
TenantID: databaseTenantID,
DatabaseName: databaseName,
}
if collectionDimension.Valid {
currentCollection.Collection.Dimension = &collectionDimension.Int32
} else {
currentCollection.Collection.Dimension = nil
}
if collectionCreatedAt.Valid {
currentCollection.Collection.CreatedAt = collectionCreatedAt.Time
}

if currentCollectionID != "" {
collections = append(collections, currentCollection)
}
collection := &dbmodel.Collection{
ID: collectionID,
Name: &collectionName,
DatabaseID: collectionDatabaseID,
LogPosition: logPosition,
Version: version,
}

collectionMetadata := &dbmodel.CollectionMetadata{
CollectionID: collectionID,
if collectionDimension.Valid {
collection.Dimension = &collectionDimension.Int32
}

if key.Valid {
collectionMetadata.Key = &key.String
} else {
collectionMetadata.Key = nil
if collectionCreatedAt.Valid {
collection.CreatedAt = collectionCreatedAt.Time
}

if strValue.Valid {
collectionMetadata.StrValue = &strValue.String
} else {
collectionMetadata.StrValue = nil
}
if intValue.Valid {
collectionMetadata.IntValue = &intValue.Int64
} else {
collectionMetadata.IntValue = nil
}
if floatValue.Valid {
collectionMetadata.FloatValue = &floatValue.Float64
} else {
collectionMetadata.FloatValue = nil
collectionWithMetdata = append(collectionWithMetdata, &dbmodel.CollectionAndMetadata{
Collection: collection,
TenantID: databaseTenantID,
DatabaseName: databaseName,
})
}
rows.Close()
for _, collection := range collectionWithMetdata {
var metadata []*dbmodel.CollectionMetadata
err = s.db.Where("collection_id = ?", collection.Collection.ID).Find(&metadata).Error
if err != nil {
log.Error("get collection metadata failed", zap.Error(err))
return nil, err
}

metadata = append(metadata, collectionMetadata)
currentCollection.CollectionMetadata = metadata
collection.CollectionMetadata = metadata
}
log.Info("collections", zap.Any("collections", collections))
return collections, nil

return
}

func (s *collectionDb) DeleteCollectionByID(collectionID string) (int, error) {
Expand Down Expand Up @@ -203,7 +150,24 @@ func generateCollectionUpdatesWithoutID(in *dbmodel.Collection) map[string]inter
func (s *collectionDb) Update(in *dbmodel.Collection) error {
log.Info("update collection", zap.Any("collection", in))
updates := generateCollectionUpdatesWithoutID(in)
return s.db.Model(&dbmodel.Collection{}).Where("id = ?", in.ID).Updates(updates).Error
err := s.db.Model(&dbmodel.Collection{}).Where("id = ?", in.ID).Updates(updates).Error
if err != nil {
log.Error("create collection failed", zap.Error(err))
var pgErr *pgconn.PgError
ok := errors.As(err, &pgErr)
if ok {
log.Error("Postgres Error")
switch pgErr.Code {
case "23505":
log.Error("collection already exists")
return common.ErrCollectionUniqueConstraintViolation
default:
return err
}
}
return err
}
return nil
}

func (s *collectionDb) UpdateLogPositionAndVersion(collectionID string, logPosition int64, currentCollectionVersion int32) (int32, error) {
Expand Down
5 changes: 5 additions & 0 deletions go/pkg/metastore/db/dao/collection_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ func (s *collectionMetadataDb) DeleteAll() error {
return s.db.Where("1 = 1").Delete(&dbmodel.CollectionMetadata{}).Error
}

// GetForCollection looks up the CollectionMetadata records whose
// collection_id column equals collectionID.
//
// Both results are named return values: metadata is filled in by Find, and
// err is the Error field of that same query.
func (s *collectionMetadataDb) GetForCollection(collectionID string) (metadata []dbmodel.CollectionMetadata, err error) {
// The ID is passed as a bound parameter ("?"), not concatenated into the SQL text.
err = s.db.Where("collection_id = ?", collectionID).Find(&metadata).Error
// Bare return hands back the named results set above.
return
}

func (s *collectionMetadataDb) DeleteByCollectionID(collectionID string) (int, error) {
var metadata []dbmodel.CollectionMetadata
err := s.db.Clauses(clause.Returning{}).Where("collection_id = ?", collectionID).Delete(&metadata).Error
Expand Down
2 changes: 1 addition & 1 deletion go/pkg/metastore/db/dao/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_GetCollections() {
offset = int32(2)
collections, err = suite.collectionDb.GetCollections(nil, nil, suite.tenantName, suite.databaseName, &limit, &offset)
suite.NoError(err)
suite.Nil(collections)
suite.Equal(len(collections), 0)

// clean up
err = CleanUpTestCollection(suite.db, collectionID)
Expand Down
Loading

0 comments on commit ef24ce2

Please sign in to comment.