Skip to content

Commit

Permalink
CSHARP-4768: Introduce $vectorSearch aggregation stage (#1187)
Browse files Browse the repository at this point in the history
  • Loading branch information
BorisDog authored Oct 6, 2023
1 parent e9d6231 commit f7e5a31
Show file tree
Hide file tree
Showing 14 changed files with 749 additions and 181 deletions.
28 changes: 28 additions & 0 deletions src/MongoDB.Driver.Core/Core/Misc/Ensure.cs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,34 @@ public static string IsNotNullOrEmpty(string value, string paramName)
return value;
}

/// <summary>
/// Ensures that the value of a parameter is not null or empty.
/// </summary>
/// <param name="value">The value of the parameter.</param>
/// <param name="paramName">The name of the parameter.</param>
/// <returns>The value of the parameter.</returns>
public static IEnumerable<T> IsNotNullOrEmpty<T>(IEnumerable<T> value, string paramName)
{
if (value == null)
{
throw new ArgumentNullException(paramName);
}

if (value is ICollection<T> collection)
{
if (collection.Count == 0)
{
throw new ArgumentException("Value cannot be empty.", paramName);
}
}
else if (!value.Any())
{
throw new ArgumentException("Value cannot be empty.", paramName);
}

return value;
}

/// <summary>
/// Ensures that the value of a parameter is null.
/// </summary>
Expand Down
9 changes: 9 additions & 0 deletions src/MongoDB.Driver/AggregateFluent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,15 @@ public override IAggregateFluent<TNewResult> Unwind<TNewResult>(FieldDefinition<
return WithPipeline(_pipeline.Unwind(field, options));
}

public override IAggregateFluent<TResult> VectorSearch(
FieldDefinition<TResult> field,
QueryVector queryVector,
int limit,
VectorSearchOptions<TResult> options = null)
{
return WithPipeline(_pipeline.VectorSearch(field, queryVector, limit, options));
}

public override string ToString()
{
var linqProvider = Database.Client.Settings.LinqProvider;
Expand Down
10 changes: 10 additions & 0 deletions src/MongoDB.Driver/AggregateFluentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,16 @@ public virtual IAggregateFluent<TNewResult> Unwind<TNewResult>(FieldDefinition<T
throw new NotImplementedException();
}

/// <inheritdoc />
public virtual IAggregateFluent<TResult> VectorSearch(
FieldDefinition<TResult> field,
QueryVector queryVector,
int limit,
VectorSearchOptions<TResult> options = null)
{
throw new NotImplementedException();
}

/// <inheritdoc />
public virtual void ToCollection(CancellationToken cancellationToken)
{
Expand Down
14 changes: 14 additions & 0 deletions src/MongoDB.Driver/IAggregateFluent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,20 @@ IAggregateFluent<TResult> UnionWith<TWith>(
/// <param name="options">The options.</param>
/// <returns>The fluent aggregate interface.</returns>
IAggregateFluent<TNewResult> Unwind<TNewResult>(FieldDefinition<TResult> field, AggregateUnwindOptions<TNewResult> options = null);

/// <summary>
/// Appends a vector search stage.
/// </summary>
/// <param name="field">The field.</param>
/// <param name="queryVector">The query vector.</param>
/// <param name="limit">The limit.</param>
/// <param name="options">The vector search options.</param>
/// <returns>The fluent aggregate interface.</returns>
IAggregateFluent<TResult> VectorSearch(
FieldDefinition<TResult> field,
QueryVector queryVector,
int limit,
VectorSearchOptions<TResult> options = null);
}

/// <summary>
Expand Down
22 changes: 22 additions & 0 deletions src/MongoDB.Driver/IAggregateFluentExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -967,5 +967,27 @@ public static IAggregateFluent<TNewResult> Unwind<TResult, TNewResult>(this IAgg

return IAsyncCursorSourceExtensions.SingleOrDefaultAsync(aggregate.Limit(2), cancellationToken);
}

/// <summary>
/// Appends a $vectorSearch stage.
/// </summary>
/// <typeparam name="TResult">The type of the result.</typeparam>
/// <param name="aggregate">The aggregate.</param>
/// <param name="field">The field.</param>
/// <param name="queryVector">The query vector.</param>
/// <param name="limit">The limit.</param>
/// <param name="options">The vector search options.</param>
/// <returns>The fluent aggregate interface.</returns>
public static IAggregateFluent<TResult> VectorSearch<TResult>(
this IAggregateFluent<TResult> aggregate,
Expression<Func<TResult, object>> field,
QueryVector queryVector,
int limit,
VectorSearchOptions<TResult> options = null)
{
Ensure.IsNotNull(aggregate, nameof(aggregate));

return aggregate.VectorSearch(new ExpressionFieldDefinition<TResult>(field), queryVector, limit, options);
}
}
}
50 changes: 50 additions & 0 deletions src/MongoDB.Driver/Linq/MongoQueryable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3484,6 +3484,56 @@ public static IOrderedMongoQueryable<TSource> ThenByDescending<TSource, TKey>(th
return (IOrderedMongoQueryable<TSource>)Queryable.ThenByDescending(source, keySelector);
}

/// <summary>
/// Appends a $vectorSearch stage to the LINQ pipeline.
/// </summary>
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
/// <typeparam name="TField">The type of the field.</typeparam>
/// <param name="source">A sequence of values.</param>
/// <param name="field">The field.</param>
/// <param name="queryVector">The query vector.</param>
/// <param name="limit">The limit.</param>
/// <param name="options">The options.</param>
/// <returns>
/// The queryable with a new stage appended.
/// </returns>
public static IMongoQueryable<TSource> VectorSearch<TSource, TField>(
this IMongoQueryable<TSource> source,
FieldDefinition<TSource> field,
QueryVector queryVector,
int limit,
VectorSearchOptions<TSource> options = null)
{
return AppendStage(
source,
PipelineStageDefinitionBuilder.VectorSearch(field, queryVector, limit, options));
}

/// <summary>
/// Appends a $vectorSearch stage to the LINQ pipeline.
/// </summary>
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
/// <typeparam name="TField">The type of the field.</typeparam>
/// <param name="source">A sequence of values.</param>
/// <param name="field">The field.</param>
/// <param name="queryVector">The query vector.</param>
/// <param name="limit">The limit.</param>
/// <param name="options">The options.</param>
/// <returns>
/// The queryable with a new stage appended.
/// </returns>
public static IMongoQueryable<TSource> VectorSearch<TSource, TField>(
this IMongoQueryable<TSource> source,
Expression<Func<TSource, TField>> field,
QueryVector queryVector,
int limit,
VectorSearchOptions<TSource> options = null)
{
return AppendStage(
source,
PipelineStageDefinitionBuilder.VectorSearch(field, queryVector, limit, options));
}

/// <summary>
/// Filters a sequence of values based on a predicate.
/// </summary>
Expand Down
Loading

0 comments on commit f7e5a31

Please sign in to comment.