From 4ddebfbbba1a57557bf0aae8fe1e572d7e804d52 Mon Sep 17 00:00:00 2001 From: rstam Date: Wed, 4 Oct 2023 10:53:20 -0700 Subject: [PATCH] CSHARP-4804: Slice projection must be rendered differently for Find and Aggregate. --- src/MongoDB.Driver/ProjectionDefinition.cs | 18 +- .../ProjectionDefinitionBuilder.cs | 130 ++++++++++++- .../Jira/CSharp4804Tests.cs | 181 ++++++++++++++++++ 3 files changed, 321 insertions(+), 8 deletions(-) create mode 100644 tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4804Tests.cs diff --git a/src/MongoDB.Driver/ProjectionDefinition.cs b/src/MongoDB.Driver/ProjectionDefinition.cs index 7f1431452f1..121621196d8 100644 --- a/src/MongoDB.Driver/ProjectionDefinition.cs +++ b/src/MongoDB.Driver/ProjectionDefinition.cs @@ -85,6 +85,9 @@ public virtual BsonDocument Render(IBsonSerializer sourceSerializer, IB /// A . public abstract BsonDocument Render(IBsonSerializer sourceSerializer, IBsonSerializerRegistry serializerRegistry, LinqProvider linqProvider); + internal virtual BsonDocument RenderForFind(IBsonSerializer sourceSerializer, IBsonSerializerRegistry serializerRegistry, LinqProvider linqProvider) + => Render(sourceSerializer, serializerRegistry, linqProvider); + /// /// Performs an implicit conversion from to . /// @@ -489,7 +492,20 @@ public IBsonSerializer ResultSerializer public override RenderedProjectionDefinition Render(IBsonSerializer sourceSerializer, IBsonSerializerRegistry serializerRegistry, LinqProvider linqProvider) { - var document = _projection.Render(sourceSerializer, serializerRegistry, linqProvider); + return Render(sourceSerializer, serializerRegistry, projection => projection.Render(sourceSerializer, serializerRegistry, linqProvider)); + } + + internal override RenderedProjectionDefinition RenderForFind(IBsonSerializer sourceSerializer, IBsonSerializerRegistry serializerRegistry, LinqProvider linqProvider) + { + return Render(sourceSerializer, serializerRegistry, projection => projection.RenderForFind(sourceSerializer, serializerRegistry, linqProvider)); + } + + private RenderedProjectionDefinition Render( + IBsonSerializer sourceSerializer, + IBsonSerializerRegistry serializerRegistry, + Func, BsonDocument> renderer) + { + var document = renderer(_projection); return new RenderedProjectionDefinition( document, _projectionSerializer ?? (sourceSerializer as IBsonSerializer) ?? serializerRegistry.GetSerializer()); diff --git a/src/MongoDB.Driver/ProjectionDefinitionBuilder.cs b/src/MongoDB.Driver/ProjectionDefinitionBuilder.cs index 8fdd6d140e5..6dea8371464 100644 --- a/src/MongoDB.Driver/ProjectionDefinitionBuilder.cs +++ b/src/MongoDB.Driver/ProjectionDefinitionBuilder.cs @@ -256,6 +256,22 @@ public static ProjectionDefinition SearchMeta( return builder.Combine(projection, builder.SearchMeta(field)); } + /// + /// Combines an existing projection with an array slice projection. + /// + /// The type of the document. + /// The projection. + /// The field. + /// The limit. + /// + /// A combined projection. + /// + public static ProjectionDefinition Slice(this ProjectionDefinition projection, FieldDefinition field, int limit) + { + var builder = Builders.Projection; + return builder.Combine(projection, builder.Slice(field, limit)); + } + /// /// Combines an existing projection with an array slice projection. /// @@ -267,12 +283,28 @@ public static ProjectionDefinition SearchMeta( /// /// A combined projection. /// - public static ProjectionDefinition Slice(this ProjectionDefinition projection, FieldDefinition field, int skip, int? limit = null) + public static ProjectionDefinition Slice(this ProjectionDefinition projection, FieldDefinition field, int skip, int limit) { var builder = Builders.Projection; return builder.Combine(projection, builder.Slice(field, skip, limit)); } + /// + /// Combines an existing projection with an array slice projection. + /// + /// The type of the document. + /// The projection. + /// The field. + /// The limit. + /// + /// A combined projection. + /// + public static ProjectionDefinition Slice(this ProjectionDefinition projection, Expression> field, int limit) + { + var builder = Builders.Projection; + return builder.Combine(projection, builder.Slice(field, limit)); + } + /// /// Combines an existing projection with an array slice projection. /// @@ -284,7 +316,7 @@ public static ProjectionDefinition Slice(this ProjectionDe /// /// A combined projection. /// - public static ProjectionDefinition Slice(this ProjectionDefinition projection, Expression> field, int skip, int? limit = null) + public static ProjectionDefinition Slice(this ProjectionDefinition projection, Expression> field, int skip, int limit) { var builder = Builders.Projection; return builder.Combine(projection, builder.Slice(field, skip, limit)); @@ -520,6 +552,19 @@ public ProjectionDefinition SearchMeta(Expression return SearchMeta(new ExpressionFieldDefinition(field)); } + /// + /// Creates an array slice projection. + /// + /// The field. + /// The limit. + /// + /// An array slice projection. + /// + public ProjectionDefinition Slice(FieldDefinition field, int limit) + { + return new SliceProjectionDefinition(field, limit); + } + /// /// Creates an array slice projection. /// @@ -529,10 +574,22 @@ public ProjectionDefinition SearchMeta(Expression /// /// An array slice projection. /// - public ProjectionDefinition Slice(FieldDefinition field, int skip, int? limit = null) + public ProjectionDefinition Slice(FieldDefinition field, int skip, int limit) { - var value = limit.HasValue ? (BsonValue)new BsonArray { skip, limit.Value } : skip; - return new SingleFieldProjectionDefinition(field, new BsonDocument("$slice", value)); + return new SliceProjectionDefinition(field, skip, limit); + } + + /// + /// Creates an array slice projection. + /// + /// The field. + /// The limit. + /// + /// An array slice projection. + /// + public ProjectionDefinition Slice(Expression> field, int limit) + { + return Slice(new ExpressionFieldDefinition(field), limit); } /// @@ -544,7 +601,7 @@ public ProjectionDefinition Slice(FieldDefinition field, int s /// /// An array slice projection. /// - public ProjectionDefinition Slice(Expression> field, int skip, int? limit = null) + public ProjectionDefinition Slice(Expression> field, int skip, int limit) { return Slice(new ExpressionFieldDefinition(field), skip, limit); } @@ -560,12 +617,22 @@ public CombinedProjectionDefinition(IEnumerable> p } public override BsonDocument Render(IBsonSerializer sourceSerializer, IBsonSerializerRegistry serializerRegistry, LinqProvider linqProvider) + { + return Render(projection => projection.Render(sourceSerializer, serializerRegistry, linqProvider)); + } + + internal override BsonDocument RenderForFind(IBsonSerializer sourceSerializer, IBsonSerializerRegistry serializerRegistry, LinqProvider linqProvider) + { + return Render(projection => projection.RenderForFind(sourceSerializer, serializerRegistry, linqProvider)); + } + + private BsonDocument Render(Func, BsonDocument> renderer) { var document = new BsonDocument(); foreach (var projection in _projections) { - var renderedProjection = projection.Render(sourceSerializer, serializerRegistry, linqProvider); + var renderedProjection = renderer(projection); foreach (var element in renderedProjection.Elements) { @@ -650,4 +717,53 @@ public override BsonDocument Render(IBsonSerializer sourceSerializer, I return new BsonDocument(renderedField.FieldName, _value); } } + + internal sealed class SliceProjectionDefinition : ProjectionDefinition + { + private readonly FieldDefinition _field; + private readonly BsonValue _limit; + private readonly BsonValue _skip; + + public SliceProjectionDefinition(FieldDefinition field, BsonValue limit) + { + _field = Ensure.IsNotNull(field, nameof(field)); + _limit = Ensure.IsNotNull(limit, nameof(limit)); + } + + public SliceProjectionDefinition(FieldDefinition field, BsonValue skip, BsonValue limit) + { + _field = Ensure.IsNotNull(field, nameof(field)); + _skip = skip; // can be null + _limit = Ensure.IsNotNull(limit, nameof(limit)); + } + + public override BsonDocument Render(IBsonSerializer sourceSerializer, IBsonSerializerRegistry serializerRegistry, LinqProvider linqProvider) + { + return Render(sourceSerializer, serializerRegistry, linqProvider, RenderArgs); + } + + internal override BsonDocument RenderForFind(IBsonSerializer sourceSerializer, IBsonSerializerRegistry serializerRegistry, LinqProvider linqProvider) + { + return Render(sourceSerializer, serializerRegistry, linqProvider, RenderArgsForFind); + } + + private BsonDocument Render(IBsonSerializer sourceSerializer, IBsonSerializerRegistry serializerRegistry, LinqProvider linqProvider, Func argsRenderer) + { + var renderedField = _field.Render(sourceSerializer, serializerRegistry, linqProvider); + var sliceArgs = argsRenderer(renderedField.FieldName); + return new BsonDocument(renderedField.FieldName, new BsonDocument("$slice", sliceArgs)); + } + + private BsonValue RenderArgs(string fieldName) + { + return _skip == null ? + new BsonArray { "$" + fieldName, _limit } : + new BsonArray { "$" + fieldName, _skip, _limit }; + } + + private BsonValue RenderArgsForFind(string fieldName) + { + return _skip == null ? _limit : new BsonArray { _skip, _limit }; + } + } } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4804Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4804Tests.cs new file mode 100644 index 00000000000..97534a76038 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4804Tests.cs @@ -0,0 +1,181 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Linq; +using FluentAssertions; +using MongoDB.Driver.Linq; +using MongoDB.TestHelpers.XunitExtensions; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira +{ + public class CSharp4804Tests : Linq3IntegrationTest + { + [Theory] + [ParameterAttributeData] + public void Find_Slice_with_field_name_and_limit_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + var projection = Builders.Projection.Slice("A", 3); + + var find = collection.Find("{}").Project(projection); + + var translatedProjection = TranslateFindProjection(collection, find); + translatedProjection.Should().Be("{ A : { $slice : 3 } }"); + + var result = find.Single(); + result["A"].AsBsonArray.Select(i => i.AsInt32).Should().Equal(1, 2, 3); + } + + [Theory] + [ParameterAttributeData] + public void Find_Slice_with_field_expression_and_limit_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + var projection = Builders.Projection.Slice(x => x.A, 3); + + var find = collection.Find("{}").Project(projection); + + var translatedProjection = TranslateFindProjection(collection, find); + translatedProjection.Should().Be("{ A : { $slice : 3 } }"); + + var result = find.Single(); + result["A"].AsBsonArray.Select(i => i.AsInt32).Should().Equal(1, 2, 3); + } + + [Theory] + [ParameterAttributeData] + public void Find_Slice_with_field_name_and_skip_and_limit_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + var projection = Builders.Projection.Slice("A", 1, 3); + + var find = collection.Find("{}").Project(projection); + + var translatedProjection = TranslateFindProjection(collection, find); + translatedProjection.Should().Be("{ A : { $slice : [1, 3] } }"); + + var result = find.Single(); + result["A"].AsBsonArray.Select(i => i.AsInt32).Should().Equal(2, 3, 4); + } + + [Theory] + [ParameterAttributeData] + public void Find_Slice_with_field_expression_and_skip_and_limit_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + var projection = Builders.Projection.Slice(x => x.A, 1, 3); + + var find = collection.Find("{}").Project(projection); + + var translatedProjection = TranslateFindProjection(collection, find); + translatedProjection.Should().Be("{ A : { $slice : [1, 3] } }"); + + var result = find.Single(); + result["A"].AsBsonArray.Select(i => i.AsInt32).Should().Equal(2, 3, 4); + } + + [Theory] + [ParameterAttributeData] + public void Aggregate_Slice_with_field_name_and_limit_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + var projection = Builders.Projection.Slice("A", 3); + + var aggregate = collection.Aggregate() + .Project(projection); + + var stages = Translate(collection, aggregate); + AssertStages(stages, "{ $project : { A : { $slice : ['$A', 3] } } }"); + + var result = aggregate.Single(); + result["A"].AsBsonArray.Select(i => i.AsInt32).Should().Equal(1, 2, 3); + } + + [Theory] + [ParameterAttributeData] + public void Aggregate_Slice_with_field_expression_and_limit_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + var projection = Builders.Projection.Slice(x => x.A, 3); + + var aggregate = collection.Aggregate() + .Project(projection); + + var stages = Translate(collection, aggregate); + AssertStages(stages, "{ $project : { A : { $slice : ['$A', 3] } } }"); + + var result = aggregate.Single(); + result["A"].AsBsonArray.Select(i => i.AsInt32).Should().Equal(1, 2, 3); + } + + [Theory] + [ParameterAttributeData] + public void AggregateSlice_with_field_name_and_skip_and_limit_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + var projection = Builders.Projection.Slice("A", 1, 3); + + var aggregate = collection.Aggregate() + .Project(projection); + + var stages = Translate(collection, aggregate); + AssertStages(stages, "{ $project : { A : { $slice : ['$A', 1, 3] } } }"); + + var result = aggregate.Single(); + result["A"].AsBsonArray.Select(i => i.AsInt32).Should().Equal(2, 3, 4); + } + + [Theory] + [ParameterAttributeData] + public void Aggregate_Slice_with_field_expression_and_skip_and_limit_should_work( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + var projection = Builders.Projection.Slice(x => x.A, 1, 3); + + var aggregate = collection.Aggregate() + .Project(projection); + + var stages = Translate(collection, aggregate); + AssertStages(stages, "{ $project : { A : { $slice : ['$A', 1, 3] } } }"); + + var result = aggregate.Single(); + result["A"].AsBsonArray.Select(i => i.AsInt32).Should().Equal(2, 3, 4); + } + + private IMongoCollection GetCollection(LinqProvider linqProvider) + { + var collection = GetCollection("test", linqProvider); + CreateCollection( + collection, + new C { Id = 1, A = new[] { 1, 2, 3, 4, 5 } }); + return collection; + } + + public class C + { + public int Id { get; set; } + public int[] A { get; set; } + } + } +}