Skip to content

Commit

Permalink
API layer changes for vector distance.
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkDuckworth committed Jul 26, 2024
1 parent f73e28b commit 72ca7c0
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 21 deletions.
93 changes: 82 additions & 11 deletions dev/src/reference/query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,9 @@ export class Query<
* @param options - Options control the vector query. `limit` specifies the upper bound of documents to return, must
* be a positive integer with a maximum value of 1000. `distanceMeasure` specifies what type of distance is calculated
* when performing the query.
*
* @deprecated Use the new {@link findNearest} implementation
* accepting `limit` and `distanceMeasure` as independent arguments.
*/
findNearest(
vectorField: string | firestore.FieldPath,
Expand All @@ -636,29 +639,97 @@ export class Query<
limit: number;
distanceMeasure: 'EUCLIDEAN' | 'COSINE' | 'DOT_PRODUCT';
}
): VectorQuery<AppModelType, DbModelType>;

/**
* Returns a query that can perform vector distance (similarity) search with given parameters.
*
* The returned query, when executed, performs a distance (similarity) search on the specified
* `vectorField` against the given `queryVector` and returns the top documents that are closest
* to the `queryVector`.
*
* Only documents whose `vectorField` field is a {@link VectorValue} of the same dimension as `queryVector`
* participate in the query, all other documents are ignored.
*
* @example
* ```
* // Returns the closest 10 documents whose Euclidean distance from their 'embedding' fields are closed to [41, 42].
* const vectorQuery = col.findNearest('embedding', [41, 42], {limit: 10, distanceMeasure: 'EUCLIDEAN'});
*
* const querySnapshot = await aggregateQuery.get();
* querySnapshot.forEach(...);
* ```
*
* @param vectorField - A string or {@link FieldPath} specifying the vector field to search on.
* @param queryVector - The {@link VectorValue} used to measure the distance from `vectorField` values in the documents.
* @param options - Options control the vector query. `limit` specifies the upper bound of documents to return, must
* be a positive integer with a maximum value of 1000. `distanceMeasure` specifies what type of distance is calculated
* when performing the query.
*/
findNearest(
vectorField: string | firestore.FieldPath,
queryVector: firestore.VectorValue | Array<number>,
limit: number,
distanceMeasure: 'EUCLIDEAN' | 'COSINE' | 'DOT_PRODUCT',
options?: {
distanceResultField?: string | firestore.FieldPath;
distanceThreshold?: number;
}): VectorQuery<AppModelType, DbModelType>;

findNearest(
vectorField: string | firestore.FieldPath,
queryVector: firestore.VectorValue | Array<number>,
limitOrOptions: number |{
limit?: number;
distanceMeasure?: 'EUCLIDEAN' | 'COSINE' | 'DOT_PRODUCT';
},
distanceMeasure?: 'EUCLIDEAN' | 'COSINE' | 'DOT_PRODUCT',
options?: {
distanceResultField?: string | firestore.FieldPath;
distanceThreshold?: number;
}
): VectorQuery<AppModelType, DbModelType> {
if (typeof limitOrOptions == 'number') {
return this._findNearest(vectorField, queryVector, limitOrOptions, distanceMeasure!, options);
} else {
return this._findNearest(vectorField, queryVector, limitOrOptions!.limit!, limitOrOptions!.distanceMeasure!);
}
}

_findNearest(
vectorField: string | firestore.FieldPath,
queryVector: firestore.VectorValue | Array<number>,
limit: number,
distanceMeasure: 'EUCLIDEAN' | 'COSINE' | 'DOT_PRODUCT',
options?: {
distanceResultField?: string | firestore.FieldPath;
distanceThreshold?: number;
}
): VectorQuery<AppModelType, DbModelType> {
validateFieldPath('vectorField', vectorField);

if (options.limit <= 0) {
throw invalidArgumentMessage('options.limit', 'positive limit number');
if (limit <= 0) {
throw invalidArgumentMessage('limit', 'positive limit number');
}

if (
(Array.isArray(queryVector)
? queryVector.length
: queryVector.toArray().length) === 0
(Array.isArray(queryVector)
? queryVector.length
: queryVector.toArray().length) === 0
) {
throw invalidArgumentMessage(
'queryVector',
'vector size must be larger than 0'
'queryVector',
'vector size must be larger than 0'
);
}

return new VectorQuery<AppModelType, DbModelType>(
this,
vectorField,
queryVector,
new VectorQueryOptions(options.limit, options.distanceMeasure)
this,
vectorField,
queryVector,
limit,
distanceMeasure,
new VectorQueryOptions(options?.distanceResultField, options?.distanceThreshold)
);
}

Expand Down
26 changes: 19 additions & 7 deletions dev/src/reference/vector-query-options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import * as firestore from '@google-cloud/firestore';
import {FieldPath} from "../path";

export class VectorQueryOptions {
readonly distanceResultField?: firestore.FieldPath;

constructor(
readonly limit: number,
readonly distanceMeasure: 'EUCLIDEAN' | 'COSINE' | 'DOT_PRODUCT'
) {}
distanceResultField?: string | firestore.FieldPath,
readonly distanceThreshold?: number
) {
if (typeof distanceResultField == 'string') {
this.distanceResultField = new FieldPath(distanceResultField);
}
}

isEqual(other: VectorQueryOptions): boolean {
if (this === other) {
Expand All @@ -28,9 +36,13 @@ export class VectorQueryOptions {
return false;
}

return (
this.limit === other.limit &&
this.distanceMeasure === other.distanceMeasure
);
let distanceResultFieldEqual = false;
if (typeof other.distanceResultField == 'undefined') {
distanceResultFieldEqual = (typeof this.distanceResultField == 'undefined');
} else {
distanceResultFieldEqual = (this.distanceResultField?.isEqual(other.distanceResultField) == true);
}

return this.distanceThreshold === other.distanceThreshold && distanceResultFieldEqual;
}
}
8 changes: 5 additions & 3 deletions dev/src/reference/vector-query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ export class VectorQuery<
private readonly _query: Query<AppModelType, DbModelType>,
private readonly vectorField: string | firestore.FieldPath,
private readonly queryVector: firestore.VectorValue | Array<number>,
private readonly limit: number,
private readonly distanceMeasure: 'EUCLIDEAN' | 'COSINE' | 'DOT_PRODUCT',
private readonly options: VectorQueryOptions
) {
this._queryUtil = new QueryUtil<
Expand Down Expand Up @@ -157,7 +159,7 @@ export class VectorQuery<
}

/**
* Internal method for serializing a query to its RunAggregationQuery proto
* Internal method for serializing a query to its proto
* representation with an optional transaction id.
*
* @private
Expand All @@ -175,8 +177,8 @@ export class VectorQuery<
: (this.queryVector as VectorValue);

queryProto.structuredQuery!.findNearest = {
limit: {value: this.options.limit},
distanceMeasure: this.options.distanceMeasure,
limit: {value: this.limit},
distanceMeasure: this.distanceMeasure,
vectorField: {
fieldPath: FieldPath.fromArgument(this.vectorField).formattedName,
},
Expand Down

0 comments on commit 72ca7c0

Please sign in to comment.