Skip to content

Commit

Permalink
CSHARP-5071: Support string concatenation of mixed types.
Browse files Browse the repository at this point in the history
  • Loading branch information
rstam committed Jun 11, 2024
1 parent 5282cb0 commit fbcf93a
Show file tree
Hide file tree
Showing 5 changed files with 593 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ internal static class StringMethod
private static readonly MethodInfo __anyStringInWithParams;
private static readonly MethodInfo __anyStringNinWithEnumerable;
private static readonly MethodInfo __anyStringNinWithParams;
private static readonly MethodInfo __concatWith1Object;
private static readonly MethodInfo __concatWith2Objects;
private static readonly MethodInfo __concatWith3Objects;
private static readonly MethodInfo __concatWithObjectArray;
private static readonly MethodInfo __concatWith2Strings;
private static readonly MethodInfo __concatWith3Strings;
private static readonly MethodInfo __concatWith4Strings;
Expand Down Expand Up @@ -108,6 +112,10 @@ static StringMethod()
__anyStringInWithParams = ReflectionInfo.Method((IEnumerable<string> s, StringOrRegularExpression[] values) => s.AnyStringIn(values));
__anyStringNinWithEnumerable = ReflectionInfo.Method((IEnumerable<string> s, IEnumerable<StringOrRegularExpression> values) => s.AnyStringNin(values));
__anyStringNinWithParams = ReflectionInfo.Method((IEnumerable<string> s, StringOrRegularExpression[] values) => s.AnyStringNin(values));
__concatWith1Object = ReflectionInfo.Method((object arg) => string.Concat(arg));
__concatWith2Objects = ReflectionInfo.Method((object arg0, object arg1) => string.Concat(arg0, arg1));
__concatWith3Objects = ReflectionInfo.Method((object arg0, object arg1, object arg2) => string.Concat(arg0, arg1, arg2));
__concatWithObjectArray = ReflectionInfo.Method((object[] args) => string.Concat(args));
__concatWith2Strings = ReflectionInfo.Method((string str0, string str1) => string.Concat(str0, str1));
__concatWith3Strings = ReflectionInfo.Method((string str0, string str1, string str2) => string.Concat(str0, str1, str2));
__concatWith4Strings = ReflectionInfo.Method((string str0, string str1, string str2, string str3) => string.Concat(str0, str1, str2, str3));
Expand Down Expand Up @@ -168,6 +176,10 @@ static StringMethod()
public static MethodInfo AnyStringInWithParams => __anyStringInWithParams;
public static MethodInfo AnyStringNinWithEnumerable => __anyStringNinWithEnumerable;
public static MethodInfo AnyStringNinWithParams => __anyStringNinWithParams;
public static MethodInfo ConcatWith1Object => __concatWith1Object;
public static MethodInfo ConcatWith2Objects => __concatWith2Objects;
public static MethodInfo ConcatWith3Objects => __concatWith3Objects;
public static MethodInfo ConcatWithObjectArray => __concatWithObjectArray;
public static MethodInfo ConcatWith2Strings => __concatWith2Strings;
public static MethodInfo ConcatWith3Strings => __concatWith3Strings;
public static MethodInfo ConcatWith4Strings => __concatWith4Strings;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods;
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators;
using MongoDB.Driver.Support;

namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators
Expand All @@ -30,6 +31,11 @@ internal static class BinaryExpressionToAggregationExpressionTranslator
{
public static AggregationExpression Translate(TranslationContext context, BinaryExpression expression)
{
if (StringConcatMethodToAggregationExpressionTranslator.CanTranslate(expression, out var method, out var arguments))
{
return StringConcatMethodToAggregationExpressionTranslator.Translate(context, expression, method, arguments);
}

if (GetTypeComparisonExpressionToAggregationExpressionTranslator.CanTranslate(expression))
{
return GetTypeComparisonExpressionToAggregationExpressionTranslator.Translate(context, expression);
Expand Down Expand Up @@ -78,9 +84,7 @@ public static AggregationExpression Translate(TranslationContext context, Binary

var ast = expression.NodeType switch
{
ExpressionType.Add => IsStringConcatenationExpression(expression) ?
AstExpression.Concat(leftTranslation.Ast, rightTranslation.Ast) :
AstExpression.Add(leftTranslation.Ast, rightTranslation.Ast),
ExpressionType.Add => AstExpression.Add(leftTranslation.Ast, rightTranslation.Ast),
ExpressionType.And => expression.Type == typeof(bool) ?
AstExpression.And(leftTranslation.Ast, rightTranslation.Ast) :
AstExpression.BitAnd(leftTranslation.Ast, rightTranslation.Ast),
Expand Down Expand Up @@ -221,15 +225,6 @@ static bool IsEnumOrConvertEnumToUnderlyingType(Expression expression)
return expression.Type.IsEnum || IsConvertEnumToUnderlyingType(expression);
}

private static bool IsStringConcatenationExpression(BinaryExpression expression)
{
return
expression.NodeType == ExpressionType.Add &&
expression.Type == typeof(string) &&
expression.Left.Type == typeof(string) &&
expression.Right.Type == typeof(string);
}

private static AstBinaryOperator ToBinaryOperator(ExpressionType nodeType)
{
return nodeType switch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ public static AggregationExpression Translate(TranslationContext context, Method
return EnumerableConcatMethodToAggregationExpressionTranslator.Translate(context, expression);
}

if (StringConcatMethodToAggregationExpressionTranslator.CanTranslate(expression))
if (StringConcatMethodToAggregationExpressionTranslator.CanTranslate(expression, out var method, out var arguments))
{
return StringConcatMethodToAggregationExpressionTranslator.Translate(context, expression);
return StringConcatMethodToAggregationExpressionTranslator.Translate(context, expression, method, arguments);
}

throw new ExpressionNotSupportedException(expression);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
*/

using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using MongoDB.Bson;
using MongoDB.Bson.IO;
using MongoDB.Bson.Serialization.Serializers;
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
Expand All @@ -27,20 +30,48 @@ internal static class StringConcatMethodToAggregationExpressionTranslator
{
private static readonly MethodInfo[] __stringConcatMethods = new[]
{
StringMethod.ConcatWith1Object,
StringMethod.ConcatWith2Objects,
StringMethod.ConcatWith3Objects,
StringMethod.ConcatWithObjectArray,
StringMethod.ConcatWith2Strings,
StringMethod.ConcatWith3Strings,
StringMethod.ConcatWith4Strings,
StringMethod.ConcatWithStringArray
};

public static bool CanTranslate(MethodCallExpression expression)
=> expression.Method.IsOneOf(__stringConcatMethods);
public static bool CanTranslate(BinaryExpression expression, out MethodInfo method, out ReadOnlyCollection<Expression> arguments)
{
if (expression.NodeType == ExpressionType.Add &&
expression.Method != null &&
expression.Method.IsOneOf(StringMethod.ConcatWith2Objects, StringMethod.ConcatWith2Strings))
{
method = expression.Method;
arguments = new ReadOnlyCollection<Expression>(new[] { expression.Left, expression.Right });
return true;
}

method = null;
arguments = null;
return false;
}

public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression)
public static bool CanTranslate(MethodCallExpression expression, out MethodInfo method, out ReadOnlyCollection<Expression> arguments)
{
var method = expression.Method;
var arguments = expression.Arguments;
if (expression.Method.IsOneOf(__stringConcatMethods))
{
method = expression.Method;
arguments = expression.Arguments;
return true;
}

method = null;
arguments = null;
return false;
}

public static AggregationExpression Translate(TranslationContext context, Expression expression, MethodInfo method, ReadOnlyCollection<Expression> arguments)
{
IEnumerable<AstExpression> argumentsTranslations = null;

if (method.IsOneOf(
Expand All @@ -52,6 +83,16 @@ public static AggregationExpression Translate(TranslationContext context, Method
arguments.Select(a => ExpressionToAggregationExpressionTranslator.Translate(context, a).Ast);
}

if (method.IsOneOf(
StringMethod.ConcatWith1Object,
StringMethod.ConcatWith2Objects,
StringMethod.ConcatWith3Objects))
{
argumentsTranslations = arguments
.Select(a => ExpressionToAggregationExpressionTranslator.Translate(context, a))
.Select(ExpressionToString);
}

if (method.Is(StringMethod.ConcatWithStringArray))
{
var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, arguments.Single());
Expand All @@ -61,13 +102,61 @@ public static AggregationExpression Translate(TranslationContext context, Method
}
}

if (method.Is(StringMethod.ConcatWithObjectArray))
{
if (arguments.Single() is NewArrayExpression newArrayExpression)
{
argumentsTranslations = newArrayExpression.Expressions
.Select(a => ExpressionToAggregationExpressionTranslator.Translate(context, a))
.Select(ExpressionToString);
}
}

if (argumentsTranslations != null)
{
var ast = AstExpression.Concat(argumentsTranslations.ToArray());
return new AggregationExpression(expression, ast, StringSerializer.Instance);
}

throw new ExpressionNotSupportedException(expression);

static AstExpression ExpressionToString(AggregationExpression aggregationExpression)
{
var astExpression = aggregationExpression.Ast;
if (aggregationExpression.Serializer.ValueType == typeof(string))
{
return astExpression;
}
else
{
if (astExpression is AstConstantExpression constantAstExpression)
{
var value = constantAstExpression.Value;
var stringValue = ValueToString(aggregationExpression.Expression, value);
return AstExpression.Constant(stringValue);
}
else
{
return AstExpression.ToString(astExpression);
}
}
}

static string ValueToString(Expression expression, BsonValue value)
{
return value switch
{
BsonBoolean booleanValue => JsonConvert.ToString(booleanValue.Value),
BsonDateTime dateTimeValue => JsonConvert.ToString(dateTimeValue.ToUniversalTime()),
BsonDecimal128 decimalValue => JsonConvert.ToString(decimalValue.Value),
BsonDouble doubleValue => JsonConvert.ToString(doubleValue.Value),
BsonInt32 int32Value => JsonConvert.ToString(int32Value.Value),
BsonInt64 int64Value => JsonConvert.ToString(int64Value.Value),
BsonObjectId objectIdValue => objectIdValue.Value.ToString(),
BsonString stringValue => stringValue.Value,
_ => throw new ExpressionNotSupportedException(expression, because: $"values represented as BSON type {value.BsonType} are not supported by $toString")
};
}
}
}
}
Loading

0 comments on commit fbcf93a

Please sign in to comment.