From 61d2c774c958d65a8ed9889e6d0ec9de2f8be0ca Mon Sep 17 00:00:00 2001 From: rstam Date: Fri, 31 May 2024 09:33:26 -0700 Subject: [PATCH] CSHARP-2509: Support Dictionary.ContainsValue in LINQ queries. --- .../Ast/Expressions/AstExpression.cs | 5 + ...essionToAggregationExpressionTranslator.cs | 1 + ...MethodToAggregationExpressionTranslator.cs | 115 +++++++++ .../ContainsValueMethodToFilterTranslator.cs | 84 +++++++ .../MethodCallExpressionToFilterTranslator.cs | 1 + .../Jira/CSharp2509Tests.cs | 220 ++++++++++++++++++ 6 files changed, 426 insertions(+) create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsValueMethodToAggregationExpressionTranslator.cs create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/ContainsValueMethodToFilterTranslator.cs create mode 100644 tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp2509Tests.cs diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs index 6e8faa2edd8..58040f476c9 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Expressions/AstExpression.cs @@ -575,6 +575,11 @@ public static AstExpression NullaryWindowExpression(AstNullaryWindowOperator @op return new AstNullaryWindowExpression(@operator, window); } + public static AstExpression ObjectToArray(AstExpression arg) + { + return new AstUnaryExpression(AstUnaryOperator.ObjectToArray, arg); + } + public static AstExpression Or(params AstExpression[] args) { Ensure.IsNotNull(args, nameof(args)); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs index 0c57aeddf29..c86dbbbb523 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs @@ -37,6 +37,7 @@ public static AggregationExpression Translate(TranslationContext context, Method case "Concat": return ConcatMethodToAggregationExpressionTranslator.Translate(context, expression); case "Contains": return ContainsMethodToAggregationExpressionTranslator.Translate(context, expression); case "ContainsKey": return ContainsKeyMethodToAggregationExpressionTranslator.Translate(context, expression); + case "ContainsValue": return ContainsValueMethodToAggregationExpressionTranslator.Translate(context, expression); case "CovariancePopulation": return CovariancePopulationMethodToAggregationExpressionTranslator.Translate(context, expression); case "CovarianceSample": return CovarianceSampleMethodToAggregationExpressionTranslator.Translate(context, expression); case "Create": return CreateMethodToAggregationExpressionTranslator.Translate(context, expression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsValueMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsValueMethodToAggregationExpressionTranslator.cs new file mode 100644 index 00000000000..c2437521c6d --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsValueMethodToAggregationExpressionTranslator.cs @@ -0,0 +1,115 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Options; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators +{ + internal static class ContainsValueMethodToAggregationExpressionTranslator + { + // public methods + public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression) + { + var method = expression.Method; + var arguments = expression.Arguments; + + if (IsContainsValueMethod(method)) + { + var dictionaryExpression = expression.Object; + var valueExpression = arguments[0]; + + var dictionaryTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, dictionaryExpression); + var dictionarySerializer = GetDictionarySerializer(expression, dictionaryTranslation); + var dictionaryRepresentation = dictionarySerializer.DictionaryRepresentation; + + var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); + var (valueBinding, valueAst) = AstExpression.UseVarIfNotSimple("value", valueTranslation.Ast); + + AstExpression ast; + switch (dictionaryRepresentation) + { + case DictionaryRepresentation.Document: + ast = AstExpression.Let( + var: valueBinding, + @in: AstExpression.Reduce( + input: AstExpression.ObjectToArray(dictionaryTranslation.Ast), + initialValue: false, + @in: AstExpression.Cond( + @if: AstExpression.Var("value"), + @then: true, + @else: AstExpression.Eq(AstExpression.GetField(AstExpression.Var("this"), "v"), valueAst)))); + break; + + case DictionaryRepresentation.ArrayOfArrays: + ast = AstExpression.Let( + var: valueBinding, + @in: AstExpression.Reduce( + input: dictionaryTranslation.Ast, + initialValue: false, + @in: AstExpression.Cond( + @if: AstExpression.Var("value"), + @then: true, + @else: AstExpression.Eq(AstExpression.ArrayElemAt(AstExpression.Var("this"), 1), valueAst)))); + break; + + case DictionaryRepresentation.ArrayOfDocuments: + ast = AstExpression.Let( + var: valueBinding, + @in: AstExpression.Reduce( + input: dictionaryTranslation.Ast, + initialValue: false, + @in: AstExpression.Cond( + @if: AstExpression.Var("value"), + @then: true, + @else: AstExpression.Eq(AstExpression.GetField(AstExpression.Var("this"), "v"), valueAst)))); + break; + + default: + throw new ExpressionNotSupportedException(expression, because: $"ContainsValue is not supported when DictionaryRepresentation is: {dictionaryRepresentation}"); + } + + return new AggregationExpression(expression, ast, BooleanSerializer.Instance); + } + + throw new ExpressionNotSupportedException(expression); + } + + private static IBsonDictionarySerializer GetDictionarySerializer(Expression expression, AggregationExpression dictionaryTranslation) + { + if (dictionaryTranslation.Serializer is IBsonDictionarySerializer dictionarySerializer) + { + return dictionarySerializer; + } + + throw new ExpressionNotSupportedException(expression, because: $"class {dictionaryTranslation.Serializer.GetType().FullName} does not implement the IBsonDictionarySerializer interface"); + } + + private static bool IsContainsValueMethod(MethodInfo method) + { + return + !method.IsStatic && + method.IsPublic && + method.ReturnType == typeof(bool) && + method.Name == "ContainsValue" && + method.GetParameters() is var parameters && + parameters.Length == 1; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/ContainsValueMethodToFilterTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/ContainsValueMethodToFilterTranslator.cs new file mode 100644 index 00000000000..256f320364e --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/ContainsValueMethodToFilterTranslator.cs @@ -0,0 +1,84 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Options; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Filters; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators.ToFilterFieldTranslators; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators.MethodTranslators +{ + internal static class ContainsValueMethodToFilterTranslator + { + public static AstFilter Translate(TranslationContext context, MethodCallExpression expression) + { + var method = expression.Method; + var arguments = expression.Arguments; + + if (IsContainsValueMethod(method)) + { + var dictionaryExpression = expression.Object; + var valueExpression = arguments[0]; + + var dictionaryField = ExpressionToFilterFieldTranslator.Translate(context, dictionaryExpression); + var dictionarySerializer = GetDictionarySerializer(expression, dictionaryField); + var dictionaryRepresentation = dictionarySerializer.DictionaryRepresentation; + var valueSerializer = dictionarySerializer.ValueSerializer; + + if (valueExpression is ConstantExpression constantValueExpression) + { + var valueField = AstFilter.Field("v", valueSerializer); + var value = constantValueExpression.Value; + var serializedValue = SerializationHelper.SerializeValue(valueSerializer, value); + + switch (dictionaryRepresentation) + { + case DictionaryRepresentation.ArrayOfDocuments: + return AstFilter.ElemMatch(dictionaryField, AstFilter.Eq(valueField, serializedValue)); + + default: + throw new ExpressionNotSupportedException(expression, because: $"ContainsValue is not supported when DictionaryRepresentation is: {dictionaryRepresentation}"); + } + } + } + + throw new ExpressionNotSupportedException(expression); + } + + private static IBsonDictionarySerializer GetDictionarySerializer(Expression expression, AstFilterField field) + { + if (field.Serializer is IBsonDictionarySerializer dictionarySerializer) + { + return dictionarySerializer; + } + + throw new ExpressionNotSupportedException(expression, because: $"class {field.Serializer.GetType().FullName} does not implement the IBsonDictionarySerializer interface"); + } + + private static bool IsContainsValueMethod(MethodInfo method) + { + return + !method.IsStatic && + method.IsPublic && + method.ReturnType == typeof(bool) && + method.Name == "ContainsValue" && + method.GetParameters() is var parameters && + parameters.Length == 1; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/MethodCallExpressionToFilterTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/MethodCallExpressionToFilterTranslator.cs index 9791b572c96..46bec0e63df 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/MethodCallExpressionToFilterTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/MethodCallExpressionToFilterTranslator.cs @@ -26,6 +26,7 @@ public static AstFilter Translate(TranslationContext context, MethodCallExpressi { case "Contains": return ContainsMethodToFilterTranslator.Translate(context, expression); case "ContainsKey": return ContainsKeyMethodToFilterTranslator.Translate(context, expression); + case "ContainsValue": return ContainsValueMethodToFilterTranslator.Translate(context, expression); case "EndsWith": return EndsWithMethodToFilterTranslator.Translate(context, expression); case "Equals": return EqualsMethodToFilterTranslator.Translate(context, expression); case "Exists": return ExistsMethodToFilterTranslator.Translate(context, expression); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp2509Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp2509Tests.cs new file mode 100644 index 00000000000..5311bb058da --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp2509Tests.cs @@ -0,0 +1,220 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Linq; +using FluentAssertions; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Bson.Serialization.Options; +using MongoDB.Driver.Linq; +using MongoDB.TestHelpers.XunitExtensions; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira +{ + public class CSharp2509Tests : Linq3IntegrationTest + { + [Theory] + [ParameterAttributeData] + public void Where_ContainsValue_should_work_when_representation_is_Dictionary( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Where(x => x.D1.ContainsValue(1)); + + if (linqProvider == LinqProvider.V2) + { + var exception = Record.Exception(() => Translate(collection, queryable)); + exception.Should().BeOfType(); + } + else + { + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $match : { $expr : { $reduce : { input : { $objectToArray : '$D1' }, initialValue : false, in : { $cond : { if : '$$value', then : true, else : { $eq : ['$$this.v', 1] } } } } } } }"); + + var results = queryable.ToList(); + results.Select(x => x.Id).Should().Equal(1, 2); + } + } + + [Theory] + [ParameterAttributeData] + public void Where_ContainsValue_should_work_when_representation_is_ArrayOfArrays( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Where(x => x.D2.ContainsValue(1)); + + if (linqProvider == LinqProvider.V2) + { + var exception = Record.Exception(() => Translate(collection, queryable)); + exception.Should().BeOfType(); + } + else + { + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $match : { $expr : { $reduce : { input : '$D2', initialValue : false, in : { $cond : { if : '$$value', then : true, else : { $eq : [{ $arrayElemAt : ['$$this', 1] }, 1] } } } } } } }"); + + var results = queryable.ToList(); + results.Select(x => x.Id).Should().Equal(1, 2); + } + } + + [Theory] + [ParameterAttributeData] + public void Where_ContainsValue_should_work_when_representation_is_ArrayOfDocuments( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Where(x => x.D3.ContainsValue(1)); + + if (linqProvider == LinqProvider.V2) + { + var exception = Record.Exception(() => Translate(collection, queryable)); + exception.Should().BeOfType(); + } + else + { + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $match : { D3 : { $elemMatch : { v : 1 } } } }"); + + var results = queryable.ToList(); + results.Select(x => x.Id).Should().Equal(1, 2); + } + } + + [Theory] + [ParameterAttributeData] + public void Select_ContainsValue_should_work_when_representation_is_Dictionary( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => x.D1.ContainsValue(1)); + + if (linqProvider == LinqProvider.V2) + { + var exception = Record.Exception(() => Translate(collection, queryable)); + exception.Should().BeOfType(); + } + else + { + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $reduce : { input : { $objectToArray : '$D1' }, initialValue : false, in : { $cond : { if : '$$value', then : true, else : { $eq : ['$$this.v', 1] } } } } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(true, true, false); + } + } + + [Theory] + [ParameterAttributeData] + public void Select_ContainsValue_should_work_when_representation_is_ArrayOfArrays( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => x.D2.ContainsValue(1)); + + if (linqProvider == LinqProvider.V2) + { + var exception = Record.Exception(() => Translate(collection, queryable)); + exception.Should().BeOfType(); + } + else + { + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $reduce : { input : '$D2', initialValue : false, in : { $cond : { if : '$$value', then : true, else : { $eq : [{ $arrayElemAt : ['$$this', 1] }, 1] } } } } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(true, true, false); + } + } + + [Theory] + [ParameterAttributeData] + public void Select_ContainsValue_should_work_when_representation_is_ArrayOfDocuments( + [Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider) + { + var collection = GetCollection(linqProvider); + + var queryable = collection.AsQueryable() + .Select(x => x.D3.ContainsValue(1)); + + if (linqProvider == LinqProvider.V2) + { + var exception = Record.Exception(() => Translate(collection, queryable)); + exception.Should().BeOfType(); + } + else + { + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $reduce : { input : '$D3', initialValue : false, in : { $cond : { if : '$$value', then : true, else : { $eq : ['$$this.v', 1] } } } } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(true, true, false); + } + } + + private IMongoCollection GetCollection(LinqProvider linqProvider) + { + var collection = GetCollection("test", linqProvider); + CreateCollection( + collection, + new User + { + Id = 1, + D1 = new() { { "A", 1 }, { "B", 2 } }, + D2 = new() { { "A", 1 }, { "B", 2 } }, + D3 = new() { { "A", 1 }, { "B", 2 } } + }, + new User + { + Id = 2, + D1 = new() { { "A", 2 }, { "B", 1 } }, + D2 = new() { { "A", 2 }, { "B", 1 } }, + D3 = new() { { "A", 2 }, { "B", 1 } } + }, + new User + { + Id = 3, + D1 = new() { { "A", 2 }, { "B", 3 } }, + D2 = new() { { "A", 2 }, { "B", 3 } }, + D3 = new() { { "A", 2 }, { "B", 3 } } + }); + return collection; + } + + private class User + { + public int Id { get; set; } + [BsonDictionaryOptions(DictionaryRepresentation.Document)] + public Dictionary D1 { get; set; } + [BsonDictionaryOptions(DictionaryRepresentation.ArrayOfArrays)] + public Dictionary D2 { get; set; } + [BsonDictionaryOptions(DictionaryRepresentation.ArrayOfDocuments)] + public Dictionary D3 { get; set; } + } + } +}