Skip to content

Commit

Permalink
Implement vectorSearch for Aggregation pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill5k committed Jul 7, 2024
1 parent 7665176 commit b97bf20
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.bson.conversions.Bson
import com.mongodb.client.model.densify.{DensifyOptions, DensifyRange}
import com.mongodb.client.model.fill.{FillOptions, FillOutputField}
import com.mongodb.client.model.geojson.Point
import com.mongodb.client.model.search.{SearchCollector, SearchOperator, SearchOptions}
import com.mongodb.client.model.search.{FieldSearchPath, SearchCollector, SearchOperator, SearchOptions, VectorSearchOptions}

trait Aggregate extends AsJava {

Expand Down Expand Up @@ -285,6 +285,34 @@ trait Aggregate extends AsJava {
options: GraphLookupOptions = new GraphLookupOptions()
): Aggregate

/** Creates a \$vectorSearch pipeline stage supported by MongoDB Atlas. You may use the \$meta: "vectorSearchScore" expression, e.g., via
* Projection.metaVectorSearchScore(String), to extract the relevance score assigned to each found document.
*
* @param queryVector
* The query vector. The number of dimensions must match that of the index.
* @param path
* The field to be searched.
* @param index
* The name of the index to use.
* @param numCandidates
* The number of candidates.
* @param limit
* The limit on the number of documents produced by the pipeline stage.
* @param options
* Optional \$vectorSearch pipeline stage fields.
* @return
* The Aggregate with \$vectorSearch pipeline stage [[https://docs.mongodb.com/manual/reference/operator/aggregation/vectorSearch/]]
* @since 4.11
*/
def vectorSearch(
path: FieldSearchPath,
queryVector: Seq[Double],
index: String,
numCandidates: Long,
limit: Long,
options: VectorSearchOptions = VectorSearchOptions.vectorSearchOptions()
): Aggregate

/** Creates a facet pipeline stage.
*
* @param facets
Expand Down Expand Up @@ -497,6 +525,15 @@ object Aggregate {
options: GraphLookupOptions = new GraphLookupOptions()
): Aggregate = empty.graphLookup(from, startWith, connectFromField, connectToField, as, options)

def vectorSearch(
path: FieldSearchPath,
queryVector: Seq[Double],
index: String,
numCandidates: Long,
limit: Long,
options: VectorSearchOptions = VectorSearchOptions.vectorSearchOptions()
): Aggregate = empty.vectorSearch(path, queryVector, index, numCandidates, limit, options)

def unionWith(collection: String, pipeline: Aggregate): Aggregate = empty.unionWith(collection, pipeline)
}

Expand Down Expand Up @@ -587,6 +624,17 @@ final private case class AggregateBuilder(
options: GraphLookupOptions = new GraphLookupOptions()
): Aggregate = AggregateBuilder(Aggregates.graphLookup(from, startWith, connectFromField, connectToField, as, options) :: aggregates)

def vectorSearch(
path: FieldSearchPath,
queryVector: Seq[Double],
index: String,
numCandidates: Long,
limit: Long,
options: VectorSearchOptions = VectorSearchOptions.vectorSearchOptions()
): Aggregate = AggregateBuilder(
Aggregates.vectorSearch(path, asJava(queryVector.map(java.lang.Double.valueOf)), index, numCandidates, limit, options) :: aggregates
)

def facet(facets: List[Aggregate.Facet]): Aggregate =
AggregateBuilder(Aggregates.facet(asJava(facets.map(_.toJava))) :: aggregates)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,18 @@ trait Projection {
*/
def slice(fieldName: String, skip: Int, limit: Int): Projection

/** Creates a projection to the given field name of the vectorSearchScore, for use with
* Aggregate.vectorSearch(FieldSearchPath,Seq,String,Long,Long,VectorSearchOptions). Calling this method is equivalent to calling
* meta(String,String) with "vectorSearchScore" as the second argument.
*
* @param fieldName
* the field name
* @return
* the projection
* @since 4.11
*/
def metaVectorSearchScore(fieldName: String): Projection

/** Merges 2 sequences of projection operations together. If there are duplicate keys, the last one takes precedence.
*
* @param anotherProjection
Expand Down Expand Up @@ -212,6 +224,7 @@ object Projection extends Projection {
def metaTextScore(fieldName: String): Projection = empty.metaTextScore(fieldName)
def metaSearchScore(fieldName: String): Projection = empty.metaSearchScore(fieldName)
def metaSearchHighlights(fieldName: String): Projection = empty.metaSearchHighlights(fieldName)
def metaVectorSearchScore(fieldName: String): Projection = empty.metaVectorSearchScore(fieldName)
def slice(fieldName: String, limit: Int): Projection = empty.slice(fieldName, limit)
def slice(fieldName: String, skip: Int, limit: Int): Projection = empty.slice(fieldName, skip, limit)
def combinedWith(anotherProjection: Projection): Projection = empty.combinedWith(anotherProjection)
Expand Down Expand Up @@ -263,6 +276,9 @@ final private case class ProjectionBuilder(
override def metaSearchHighlights(fieldName: String): Projection =
ProjectionBuilder(Projections.metaSearchHighlights(fieldName) :: projections)

def metaVectorSearchScore(fieldName: String): Projection =
ProjectionBuilder(Projections.metaVectorSearchScore(fieldName) :: projections)

override def computedSearchMeta(fieldName: String): Projection =
ProjectionBuilder(Projections.computedSearchMeta(fieldName) :: projections)

Expand Down

0 comments on commit b97bf20

Please sign in to comment.