Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Translate non-aggregate string.Join to CONCAT_WS on SQL Server #28900

Merged
merged 1 commit into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Globalization;
using System.Text;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.SqlServer.Infrastructure.Internal;
using ExpressionExtensions = Microsoft.EntityFrameworkCore.Query.ExpressionExtensions;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
Expand All @@ -19,6 +20,8 @@ public class SqlServerSqlTranslatingExpressionVisitor : RelationalSqlTranslating
{
private readonly SqlServerQueryCompilationContext _queryCompilationContext;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly IRelationalTypeMappingSource _typeMappingSource;
private readonly int _sqlServerCompatibilityLevel;

private static readonly HashSet<string> DateTimeDataTypes
=
Expand Down Expand Up @@ -59,6 +62,9 @@ private static readonly MethodInfo StringEndsWithMethodInfo
private static readonly MethodInfo StringContainsMethodInfo
= typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(string)])!;

private static readonly MethodInfo StringJoinMethodInfo
= typeof(string).GetRuntimeMethod(nameof(string.Join), [typeof(string), typeof(string[])])!;

private static readonly MethodInfo EscapeLikePatternParameterMethod =
typeof(SqlServerSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ConstructLikePatternParameter))!;

Expand All @@ -74,11 +80,14 @@ private static readonly MethodInfo StringContainsMethodInfo
public SqlServerSqlTranslatingExpressionVisitor(
RelationalSqlTranslatingExpressionVisitorDependencies dependencies,
SqlServerQueryCompilationContext queryCompilationContext,
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor)
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor,
ISqlServerSingletonOptions sqlServerSingletonOptions)
: base(dependencies, queryCompilationContext, queryableMethodTranslatingExpressionVisitor)
{
_queryCompilationContext = queryCompilationContext;
_sqlExpressionFactory = dependencies.SqlExpressionFactory;
_typeMappingSource = dependencies.TypeMappingSource;
_sqlServerCompatibilityLevel = sqlServerSingletonOptions.CompatibilityLevel;
}

/// <summary>
Expand Down Expand Up @@ -199,6 +208,60 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return translation3;
}

// Translate non-aggregate string.Join to CONCAT_WS (for aggregate string.Join, see SqlServerStringAggregateMethodTranslator)
if (method == StringJoinMethodInfo
&& methodCallExpression.Arguments[1] is NewArrayExpression newArrayExpression
&& _sqlServerCompatibilityLevel >= 140)
{
if (TranslationFailed(methodCallExpression.Arguments[0], Visit(methodCallExpression.Arguments[0]), out var delimiter))
{
return QueryCompilationContext.NotTranslatedExpression;
}

var arguments = new SqlExpression[newArrayExpression.Expressions.Count + 1];
arguments[0] = delimiter!;

var isUnicode = delimiter!.TypeMapping?.IsUnicode == true;

for (var i = 0; i < newArrayExpression.Expressions.Count; i++)
{
var argument = newArrayExpression.Expressions[i];
if (TranslationFailed(argument, Visit(argument), out var sqlArgument))
{
return QueryCompilationContext.NotTranslatedExpression;
}

// CONCAT_WS returns a type with a length that varies based on actual inputs (i.e. the sum of all argument lengths, plus
// the length needed for the delimiters). We don't know column values (or even parameter values, so we always return max.
// We do vary return varchar(max) or nvarchar(max) based on whether we saw any nvarchar mapping.
if (sqlArgument!.TypeMapping?.IsUnicode == true)
{
isUnicode = true;
}

// CONCAT_WS filters out nulls, but string.Join treats them as empty strings; so coalesce (which is a no-op for non-nullable
// arguments).
arguments[i + 1] = sqlArgument switch
{
ColumnExpression { IsNullable: false } => sqlArgument,
SqlConstantExpression constantExpression => constantExpression.Value is null
? _sqlExpressionFactory.Constant(string.Empty)
: constantExpression,
_ => Dependencies.SqlExpressionFactory.Coalesce(sqlArgument, _sqlExpressionFactory.Constant(string.Empty))
};
}

// CONCAT_WS never returns null; a null delimiter is interpreted as an empty string, and null arguments are skipped
// (but we coalesce them above in any case).
return Dependencies.SqlExpressionFactory.Function(
"CONCAT_WS",
arguments,
nullable: false,
argumentsPropagateNullability: new bool[arguments.Length],
typeof(string),
_typeMappingSource.FindMapping(isUnicode ? "nvarchar(max)" : "varchar(max)"));
}

return base.VisitMethodCall(methodCallExpression);

bool TryTranslateStartsEndsWithContains(
Expand Down Expand Up @@ -515,6 +578,20 @@ private Expression TranslateByteArrayElementAccess(Expression array, Expression
: QueryCompilationContext.NotTranslatedExpression;
}

[DebuggerStepThrough]
private static bool TranslationFailed(Expression? original, Expression? translation, out SqlExpression? castTranslation)
{
if (original != null
&& translation is not SqlExpression)
{
castTranslation = null;
return true;
}

castTranslation = translation as SqlExpression;
return false;
}

private static string? GetProviderType(SqlExpression expression)
=> expression.TypeMapping?.StoreType;
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.EntityFrameworkCore.SqlServer.Infrastructure.Internal;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;

/// <summary>
Expand All @@ -11,16 +13,20 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
/// </summary>
public class SqlServerSqlTranslatingExpressionVisitorFactory : IRelationalSqlTranslatingExpressionVisitorFactory
{
private readonly ISqlServerSingletonOptions _sqlServerSingletonOptions;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public SqlServerSqlTranslatingExpressionVisitorFactory(
RelationalSqlTranslatingExpressionVisitorDependencies dependencies)
RelationalSqlTranslatingExpressionVisitorDependencies dependencies,
ISqlServerSingletonOptions sqlServerSingletonOptions)
{
Dependencies = dependencies;
_sqlServerSingletonOptions = sqlServerSingletonOptions;
}

/// <summary>
Expand All @@ -40,5 +46,6 @@ public virtual RelationalSqlTranslatingExpressionVisitor Create(
=> new SqlServerSqlTranslatingExpressionVisitor(
Dependencies,
(SqlServerQueryCompilationContext)queryCompilationContext,
queryableMethodTranslatingExpressionVisitor);
queryableMethodTranslatingExpressionVisitor,
_sqlServerSingletonOptions);
}
Original file line number Diff line number Diff line change
Expand Up @@ -1940,6 +1940,9 @@ public override Task String_Join_with_ordering(bool async)
public override Task String_Join_over_nullable_column(bool async)
=> AssertTranslationFailed(() => base.String_Join_over_nullable_column(async));

public override Task String_Join_non_aggregate(bool async)
=> AssertTranslationFailed(() => base.String_Join_non_aggregate(async));

public override Task String_Concat(bool async)
=> AssertTranslationFailed(() => base.String_Concat(async));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,17 @@ public virtual Task String_Join_over_nullable_column(bool async)
a.Regions.Split("|").OrderBy(id => id).ToArray());
});

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_Join_non_aggregate(bool async)
{
var foo = "foo";

return AssertQuery(
async,
ss => ss.Set<Customer>().Where(c => string.Join("|", new[] { c.CompanyName, foo, null, "bar" }) == "Around the Horn|foo||bar"));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task String_Concat(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,21 @@ GROUP BY [c].[City]
""");
}

[SqlServerCondition(SqlServerCondition.SupportsFunctions2017)]
public override async Task String_Join_non_aggregate(bool async)
{
await base.String_Join_non_aggregate(async);

AssertSql(
"""
@__foo_0='foo' (Size = 4000)

SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE CONCAT_WS(N'|', [c].[CompanyName], COALESCE(@__foo_0, N''), N'', N'bar') = N'Around the Horn|foo||bar'
""");
}

[SqlServerCondition(SqlServerCondition.SupportsFunctions2017)]
public override async Task String_Concat(bool async)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,9 @@ GROUP BY "c"."City"
""");
}

public override Task String_Join_non_aggregate(bool async)
=> AssertTranslationFailed(() => base.String_Join_non_aggregate(async));

public override async Task String_Concat(bool async)
{
await base.String_Concat(async);
Expand Down