Skip to content

Commit

Permalink
Translate non-aggregate string.Join
Browse files Browse the repository at this point in the history
Closes #28899
  • Loading branch information
roji committed Jun 4, 2024
1 parent 16acc46 commit b484f9b
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 6 deletions.
5 changes: 3 additions & 2 deletions src/EFCore.Relational/Query/ISqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -217,14 +217,15 @@ SqlBinaryExpression Or(
RelationalTypeMapping? typeMapping = null);

// Other

/// <summary>
/// Creates a <see cref="SqlFunctionExpression" /> which represents a COALESCE operation.
/// Creates an expression which represents a SQL COALESCE operation.
/// </summary>
/// <param name="left">The left operand.</param>
/// <param name="right">The right operand.</param>
/// <param name="typeMapping">A type mapping to be assigned to the created expression.</param>
/// <returns>An expression representing a SQL COALESCE operation.</returns>
SqlFunctionExpression Coalesce(
SqlExpression Coalesce(
SqlExpression left,
SqlExpression right,
RelationalTypeMapping? typeMapping = null);
Expand Down
13 changes: 12 additions & 1 deletion src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -481,8 +481,19 @@ public virtual SqlBinaryExpression Or(SqlExpression left, SqlExpression right, R
=> MakeBinary(ExpressionType.Or, left, right, typeMapping)!;

/// <inheritdoc />
public virtual SqlFunctionExpression Coalesce(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression Coalesce(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
{
// TODO: generalize this if we ever introduce nullability on SqlExpression, #33889
switch (left)
{
case SqlConstantExpression { Value: not null }:
case ColumnExpression { IsNullable: false }:
return left;

case SqlConstantExpression { Value: null }:
return right;
}

var resultType = right.Type;
var inferredTypeMapping = typeMapping
?? ExpressionExtensions.InferTypeMapping(left, right)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Globalization;
using System.Text;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.SqlServer.Infrastructure.Internal;
using ExpressionExtensions = Microsoft.EntityFrameworkCore.Query.ExpressionExtensions;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
Expand All @@ -19,6 +20,8 @@ public class SqlServerSqlTranslatingExpressionVisitor : RelationalSqlTranslating
{
private readonly SqlServerQueryCompilationContext _queryCompilationContext;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly IRelationalTypeMappingSource _typeMappingSource;
private readonly int _sqlServerCompatibilityLevel;

private static readonly HashSet<string> DateTimeDataTypes
=
Expand Down Expand Up @@ -59,6 +62,9 @@ private static readonly MethodInfo StringEndsWithMethodInfo
private static readonly MethodInfo StringContainsMethodInfo
= typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(string)])!;

private static readonly MethodInfo StringJoinMethodInfo
= typeof(string).GetRuntimeMethod(nameof(string.Join), [typeof(string), typeof(string[])])!;

private static readonly MethodInfo EscapeLikePatternParameterMethod =
typeof(SqlServerSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ConstructLikePatternParameter))!;

Expand All @@ -74,11 +80,14 @@ private static readonly MethodInfo StringContainsMethodInfo
public SqlServerSqlTranslatingExpressionVisitor(
RelationalSqlTranslatingExpressionVisitorDependencies dependencies,
SqlServerQueryCompilationContext queryCompilationContext,
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor)
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor,
ISqlServerSingletonOptions sqlServerSingletonOptions)
: base(dependencies, queryCompilationContext, queryableMethodTranslatingExpressionVisitor)
{
_queryCompilationContext = queryCompilationContext;
_sqlExpressionFactory = dependencies.SqlExpressionFactory;
_typeMappingSource = dependencies.TypeMappingSource;
_sqlServerCompatibilityLevel = sqlServerSingletonOptions.CompatibilityLevel;
}

/// <summary>
Expand Down Expand Up @@ -199,6 +208,60 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return translation3;
}

// Translate non-aggregate string.Join to CONCAT_WS (for aggregate string.Join, see SqlServerStringAggregateMethodTranslator)
if (method == StringJoinMethodInfo
&& methodCallExpression.Arguments[1] is NewArrayExpression newArrayExpression
&& _sqlServerCompatibilityLevel >= 140)
{
if (TranslationFailed(methodCallExpression.Arguments[0], Visit(methodCallExpression.Arguments[0]), out var delimiter))
{
return QueryCompilationContext.NotTranslatedExpression;
}

var arguments = new SqlExpression[newArrayExpression.Expressions.Count + 1];
arguments[0] = delimiter!;

var isUnicode = delimiter!.TypeMapping?.IsUnicode == true;

for (var i = 0; i < newArrayExpression.Expressions.Count; i++)
{
var argument = newArrayExpression.Expressions[i];
if (TranslationFailed(argument, Visit(argument), out var sqlArgument))
{
return QueryCompilationContext.NotTranslatedExpression;
}

// CONCAT_WS returns a type with a length that varies based on actual inputs (i.e. the sum of all argument lengths, plus
// the length needed for the delimiters). We don't know column values (or even parameter values, so we always return max.
// We do vary return varchar(max) or nvarchar(max) based on whether we saw any nvarchar mapping.
if (sqlArgument!.TypeMapping?.IsUnicode == true)
{
isUnicode = true;
}

// CONCAT_WS filters out nulls, but string.Join treats them as empty strings; so coalesce (which is a no-op for non-nullable
// arguments).
arguments[i + 1] = sqlArgument switch
{
ColumnExpression { IsNullable: false } => sqlArgument,
SqlConstantExpression constantExpression => constantExpression.Value is null
? _sqlExpressionFactory.Constant(string.Empty)
: constantExpression,
_ => Dependencies.SqlExpressionFactory.Coalesce(sqlArgument, _sqlExpressionFactory.Constant(string.Empty))
};
}

// CONCAT_WS never returns null; a null delimiter is interpreted as an empty string, and null arguments are skipped
// (but we coalesce them above in any case).
return Dependencies.SqlExpressionFactory.Function(
"CONCAT_WS",
arguments,
nullable: false,
argumentsPropagateNullability: new bool[arguments.Length],
typeof(string),
_typeMappingSource.FindMapping(isUnicode ? "nvarchar(max)" : "varchar(max)"));
}

return base.VisitMethodCall(methodCallExpression);

bool TryTranslateStartsEndsWithContains(
Expand Down Expand Up @@ -515,6 +578,20 @@ private Expression TranslateByteArrayElementAccess(Expression array, Expression
: QueryCompilationContext.NotTranslatedExpression;
}

[DebuggerStepThrough]
private static bool TranslationFailed(Expression? original, Expression? translation, out SqlExpression? castTranslation)
{
if (original != null
&& translation is not SqlExpression)
{
castTranslation = null;
return true;
}

castTranslation = translation as SqlExpression;
return false;
}

private static string? GetProviderType(SqlExpression expression)
=> expression.TypeMapping?.StoreType;
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.EntityFrameworkCore.SqlServer.Infrastructure.Internal;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;

/// <summary>
Expand All @@ -11,16 +13,20 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
/// </summary>
public class SqlServerSqlTranslatingExpressionVisitorFactory : IRelationalSqlTranslatingExpressionVisitorFactory
{
private readonly ISqlServerSingletonOptions _sqlServerSingletonOptions;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public SqlServerSqlTranslatingExpressionVisitorFactory(
RelationalSqlTranslatingExpressionVisitorDependencies dependencies)
RelationalSqlTranslatingExpressionVisitorDependencies dependencies,
ISqlServerSingletonOptions sqlServerSingletonOptions)
{
Dependencies = dependencies;
_sqlServerSingletonOptions = sqlServerSingletonOptions;
}

/// <summary>
Expand All @@ -40,5 +46,6 @@ public virtual RelationalSqlTranslatingExpressionVisitor Create(
=> new SqlServerSqlTranslatingExpressionVisitor(
Dependencies,
(SqlServerQueryCompilationContext)queryCompilationContext,
queryableMethodTranslatingExpressionVisitor);
queryableMethodTranslatingExpressionVisitor,
_sqlServerSingletonOptions);
}
Original file line number Diff line number Diff line change
Expand Up @@ -1940,6 +1940,9 @@ public override Task String_Join_with_ordering(bool async)
public override Task String_Join_over_nullable_column(bool async)
=> AssertTranslationFailed(() => base.String_Join_over_nullable_column(async));

public override Task String_Join_non_aggregate(bool async)
=> AssertTranslationFailed(() => base.String_Join_non_aggregate(async));

public override Task String_Concat(bool async)
=> AssertTranslationFailed(() => base.String_Concat(async));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,17 @@ public virtual Task String_Join_over_nullable_column(bool async)
a.Regions.Split("|").OrderBy(id => id).ToArray());
});

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_Join_non_aggregate(bool async)
{
var foo = "foo";

return AssertQuery(
async,
ss => ss.Set<Customer>().Where(c => string.Join("|", new[] { c.CompanyName, foo, null, "bar" }) == "Around the Horn|foo||bar"));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_Concat(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,21 @@ GROUP BY [c].[City]
""");
}

[SqlServerCondition(SqlServerCondition.SupportsFunctions2017)]
public override async Task String_Join_non_aggregate(bool async)
{
await base.String_Join_non_aggregate(async);

AssertSql(
"""
@__foo_0='foo' (Size = 4000)

SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE CONCAT_WS(N'|', [c].[CompanyName], COALESCE(@__foo_0, N''), N'', N'bar') = N'Around the Horn|foo||bar'
""");
}

[SqlServerCondition(SqlServerCondition.SupportsFunctions2017)]
public override async Task String_Concat(bool async)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,9 @@ GROUP BY "c"."City"
""");
}

public override Task String_Join_non_aggregate(bool async)
=> AssertTranslationFailed(() => base.String_Join_non_aggregate(async));

public override async Task String_Concat(bool async)
{
await base.String_Concat(async);
Expand Down

0 comments on commit b484f9b

Please sign in to comment.