Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(firestore): minor tweaks and doc for vector search #10583

Merged
merged 2 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions firestore/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ as a query.

iter = client.Collection("States").Documents(ctx)

Firestore supports similarity search over embedding vectors. See [Query.FindNearest]
for details.

# Collection Group Partition Queries

You can partition the documents of a Collection Group allowing for smaller subqueries.
Expand Down
24 changes: 24 additions & 0 deletions firestore/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,30 @@ func ExampleQuery_Snapshots() {
}
}

// This example demonstrates how to use Firestore vector search.
// It assumes that the database has a collection "descriptions"
// in which each document has a field of type Vector32 or Vector64
// called "Embedding":
//
//	type Description struct {
//		// ...
//		Embedding firestore.Vector32
//	}
func ExampleQuery_FindNearest() {
ctx := context.Background()
client, err := firestore.NewClient(ctx, "project-id")
if err != nil {
// TODO: Handle error.
}
defer client.Close()

// Build a vector query over the "descriptions" collection: compare each
// document's "Embedding" field against the query vector {1, 2, 3} using the
// dot-product distance measure, with a limit of 5. No options are passed (nil).
q := client.Collection("descriptions").
FindNearest("Embedding", []float32{1, 2, 3}, 5, firestore.DistanceMeasureDotProduct, nil)
// Obtain a document iterator for the vector query.
iter1 := q.Documents(ctx)
_ = iter1 // TODO: Use iter1.
}

func ExampleDocumentIterator_Next() {
ctx := context.Background()
client, err := firestore.NewClient(ctx, "project-id")
Expand Down
2 changes: 1 addition & 1 deletion firestore/from_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ func createFromProtoValue(vproto *pb.Value, c *Client) (interface{}, error) {
}

// Special handling for vector
return vectorFromProtoValue(vproto)
return vector64FromProtoValue(vproto)
default:
return nil, fmt.Errorf("firestore: unknown value type %T", v)
}
Expand Down
3 changes: 3 additions & 0 deletions firestore/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2356,6 +2356,9 @@ func TestIntegration_NewClientWithDatabase(t *testing.T) {
if testing.Short() {
t.Skip("Integration tests skipped in short mode")
}
if iClient == nil {
t.Skip("Integration test skipped: did not create client")
}
for _, tc := range []struct {
desc string
dbName string
Expand Down
32 changes: 18 additions & 14 deletions firestore/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ type DistanceMeasure int32

const (
// DistanceMeasureEuclidean is used to measure the Euclidean distance between the vectors. See
// [Euclidean] to learn more
// [Euclidean] to learn more.
//
// [Euclidean]: https://en.wikipedia.org/wiki/Euclidean_distance
DistanceMeasureEuclidean DistanceMeasure = DistanceMeasure(pb.StructuredQuery_FindNearest_EUCLIDEAN)
Expand All @@ -393,33 +393,39 @@ const (
)

// FindNearestOptions are options for a FindNearest vector query.
// At present, there are no options.
type FindNearestOptions struct {
}

// VectorQuery represents a vector query
// VectorQuery represents a query that uses [Query.FindNearest] or [Query.FindNearestPath].
type VectorQuery struct {
q Query
}

// FindNearest returns a query that can perform vector distance (similarity) search with given parameters.
// FindNearest returns a query that can perform vector distance (similarity) search.
//
// The returned query, when executed, performs a distance (similarity) search on the specified
// The returned query, when executed, performs a distance search on the specified
// vectorField against the given queryVector and returns the top documents that are closest
// to the queryVector;.
// to the queryVector according to measure. At most limit documents are returned.
//
// Only documents whose vectorField field is a Vector of the same dimension as queryVector
// participate in the query, all other documents are ignored.
// Only documents whose vectorField field is a Vector32 or Vector64 of the same dimension
// as queryVector participate in the query; all other documents are ignored.
// In particular, fields of type []float32 or []float64 are ignored.
//
// The vectorField argument can be a single field or a dot-separated sequence of
// fields, and must not contain any of the runes "~*/[]".
//
// The queryVector argument can be any of the following types:
// - []float32
// - []float64
// - Vector32
// - Vector64
func (q Query) FindNearest(vectorField string, queryVector any, limit int, measure DistanceMeasure, options *FindNearestOptions) VectorQuery {
// Validate field path
fieldPath, err := parseDotSeparatedString(vectorField)
if err != nil {
q.err = err
return VectorQuery{
q: q,
}
return VectorQuery{q: q}
}
return q.FindNearestPath(fieldPath, queryVector, limit, measure, options)
}
Expand All @@ -429,11 +435,9 @@ func (vq VectorQuery) Documents(ctx context.Context) *DocumentIterator {
return vq.q.Documents(ctx)
}

// FindNearestPath is similar to FindNearest but it accepts a [FieldPath].
// FindNearestPath is like [Query.FindNearest] but it accepts a [FieldPath].
func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector any, limit int, measure DistanceMeasure, options *FindNearestOptions) VectorQuery {
vq := VectorQuery{
q: q,
}
vq := VectorQuery{q: q}

// Convert field path to field reference
vectorFieldRef, err := fref(vectorFieldPath)
Expand Down
1 change: 1 addition & 0 deletions firestore/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,7 @@ func TestQueryToProto(t *testing.T) {

// Convert a Query to a Proto and back again verifying roundtripping
func TestQueryFromProtoRoundTrip(t *testing.T) {
t.Skip("flaky due to random map order iteration")
c := &Client{projectID: "P", databaseID: "DB"}

for _, test := range createTestScenarios(t) {
Expand Down
30 changes: 8 additions & 22 deletions firestore/vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ type Vector64 []float64
type Vector32 []float32

// vectorToProtoValue returns a Firestore [pb.Value] representing the Vector.
// The calling function should check for type safety
func vectorToProtoValue[vType float32 | float64](v []vType) *pb.Value {
func vectorToProtoValue[T float32 | float64](v []T) *pb.Value {
if v == nil {
return nullValue
}
Expand All @@ -59,40 +58,27 @@ func vectorToProtoValue[vType float32 | float64](v []vType) *pb.Value {
}
}

func vectorFromProtoValue(v *pb.Value) (interface{}, error) {
return vector64FromProtoValue(v)
}

func vector32FromProtoValue(v *pb.Value) (Vector32, error) {
pbArrVals, err := pbValToVectorVals(v)
if err != nil {
return nil, err
}

floats := make([]float32, len(pbArrVals))
for i, fval := range pbArrVals {
dv, ok := fval.ValueType.(*pb.Value_DoubleValue)
if !ok {
return nil, fmt.Errorf("firestore: failed to convert %v to *pb.Value_DoubleValue", fval.ValueType)
}
floats[i] = float32(dv.DoubleValue)
}
return floats, nil
return vectorFromProtoValue[float32](v)
}

// vector64FromProtoValue converts the proto value v into a Vector64.
// It delegates to the generic vectorFromProtoValue instantiated with float64
// and returns that call's result and error unchanged.
func vector64FromProtoValue(v *pb.Value) (Vector64, error) {
return vectorFromProtoValue[float64](v)
}

func vectorFromProtoValue[T float32 | float64](v *pb.Value) ([]T, error) {
pbArrVals, err := pbValToVectorVals(v)
if err != nil {
return nil, err
}

floats := make([]float64, len(pbArrVals))
floats := make([]T, len(pbArrVals))
for i, fval := range pbArrVals {
dv, ok := fval.ValueType.(*pb.Value_DoubleValue)
if !ok {
return nil, fmt.Errorf("firestore: failed to convert %v to *pb.Value_DoubleValue", fval.ValueType)
}
floats[i] = dv.DoubleValue
floats[i] = T(dv.DoubleValue)
}
return floats, nil
}
Expand Down
2 changes: 1 addition & 1 deletion firestore/vector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func TestVectorFromProtoValue(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := vectorFromProtoValue(tt.v)
got, err := vector64FromProtoValue(tt.v)
if (err != nil) != tt.wantErr {
t.Errorf("vectorFromProtoValue() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down
Loading