Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize away Coalesce for trivial cases #34002

Merged
merged 4 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 73 additions & 73 deletions src/EFCore.Relational/Query/ISqlExpressionFactory.cs

Large diffs are not rendered by default.

126 changes: 66 additions & 60 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ private SqlExpression ApplyTypeMappingOnJsonScalar(
}

/// <inheritdoc />
public virtual SqlBinaryExpression? MakeBinary(
public virtual SqlExpression? MakeBinary(
ExpressionType operatorType,
SqlExpression left,
SqlExpression right,
Expand All @@ -416,125 +416,131 @@ private SqlExpression ApplyTypeMappingOnJsonScalar(
break;
}

return (SqlBinaryExpression)ApplyTypeMapping(
return ApplyTypeMapping(
new SqlBinaryExpression(operatorType, left, right, returnType, null), typeMapping);
}

/// <inheritdoc />
public virtual SqlBinaryExpression Equal(SqlExpression left, SqlExpression right)
public virtual SqlExpression Equal(SqlExpression left, SqlExpression right)
=> MakeBinary(ExpressionType.Equal, left, right, null)!;

/// <inheritdoc />
public virtual SqlBinaryExpression NotEqual(SqlExpression left, SqlExpression right)
public virtual SqlExpression NotEqual(SqlExpression left, SqlExpression right)
=> MakeBinary(ExpressionType.NotEqual, left, right, null)!;

/// <inheritdoc />
public virtual SqlBinaryExpression GreaterThan(SqlExpression left, SqlExpression right)
public virtual SqlExpression GreaterThan(SqlExpression left, SqlExpression right)
=> MakeBinary(ExpressionType.GreaterThan, left, right, null)!;

/// <inheritdoc />
public virtual SqlBinaryExpression GreaterThanOrEqual(SqlExpression left, SqlExpression right)
public virtual SqlExpression GreaterThanOrEqual(SqlExpression left, SqlExpression right)
=> MakeBinary(ExpressionType.GreaterThanOrEqual, left, right, null)!;

/// <inheritdoc />
public virtual SqlBinaryExpression LessThan(SqlExpression left, SqlExpression right)
public virtual SqlExpression LessThan(SqlExpression left, SqlExpression right)
=> MakeBinary(ExpressionType.LessThan, left, right, null)!;

/// <inheritdoc />
public virtual SqlBinaryExpression LessThanOrEqual(SqlExpression left, SqlExpression right)
public virtual SqlExpression LessThanOrEqual(SqlExpression left, SqlExpression right)
=> MakeBinary(ExpressionType.LessThanOrEqual, left, right, null)!;

/// <inheritdoc />
public virtual SqlBinaryExpression AndAlso(SqlExpression left, SqlExpression right)
public virtual SqlExpression AndAlso(SqlExpression left, SqlExpression right)
=> MakeBinary(ExpressionType.AndAlso, left, right, null)!;

/// <inheritdoc />
public virtual SqlBinaryExpression OrElse(SqlExpression left, SqlExpression right)
public virtual SqlExpression OrElse(SqlExpression left, SqlExpression right)
=> MakeBinary(ExpressionType.OrElse, left, right, null)!;

/// <inheritdoc />
public virtual SqlBinaryExpression Add(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression Add(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
=> MakeBinary(ExpressionType.Add, left, right, typeMapping)!;

/// <inheritdoc />
public virtual SqlBinaryExpression Subtract(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression Subtract(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
=> MakeBinary(ExpressionType.Subtract, left, right, typeMapping)!;

/// <inheritdoc />
public virtual SqlBinaryExpression Multiply(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression Multiply(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
=> MakeBinary(ExpressionType.Multiply, left, right, typeMapping)!;

/// <inheritdoc />
public virtual SqlBinaryExpression Divide(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression Divide(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
=> MakeBinary(ExpressionType.Divide, left, right, typeMapping)!;

/// <inheritdoc />
public virtual SqlBinaryExpression Modulo(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression Modulo(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
=> MakeBinary(ExpressionType.Modulo, left, right, typeMapping)!;

/// <inheritdoc />
public virtual SqlBinaryExpression And(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression And(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
=> MakeBinary(ExpressionType.And, left, right, typeMapping)!;

/// <inheritdoc />
public virtual SqlBinaryExpression Or(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression Or(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
=> MakeBinary(ExpressionType.Or, left, right, typeMapping)!;

/// <inheritdoc />
public virtual SqlFunctionExpression Coalesce(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression Coalesce(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
{
var resultType = right.Type;
var inferredTypeMapping = typeMapping
?? ExpressionExtensions.InferTypeMapping(left, right)
?? _typeMappingSource.FindMapping(resultType, Dependencies.Model);

var typeMappedArguments = new List<SqlExpression>
left = ApplyTypeMapping(left, inferredTypeMapping);
right = ApplyTypeMapping(right, inferredTypeMapping);

return left switch
{
ApplyTypeMapping(left, inferredTypeMapping), ApplyTypeMapping(right, inferredTypeMapping)
SqlConstantExpression { Value: null } => right,

SqlConstantExpression { Value: not null } or
ColumnExpression { IsNullable: false } => left,

_ => new SqlFunctionExpression(
"COALESCE",
[left, right],
nullable: true,
// COALESCE is handled separately since it's only nullable if *all* arguments are null
argumentsPropagateNullability: [false, false],
resultType,
inferredTypeMapping)
};

return new SqlFunctionExpression(
"COALESCE",
typeMappedArguments,
nullable: true,
// COALESCE is handled separately since it's only nullable if *all* arguments are null
argumentsPropagateNullability: [false, false],
resultType,
inferredTypeMapping);
}

/// <inheritdoc />
public virtual SqlUnaryExpression? MakeUnary(
public virtual SqlExpression? MakeUnary(
ExpressionType operatorType,
SqlExpression operand,
Type type,
RelationalTypeMapping? typeMapping = null)
=> SqlUnaryExpression.IsValidOperator(operatorType)
? (SqlUnaryExpression)ApplyTypeMapping(new SqlUnaryExpression(operatorType, operand, type, null), typeMapping)
? ApplyTypeMapping(new SqlUnaryExpression(operatorType, operand, type, null), typeMapping)
: null;

/// <inheritdoc />
public virtual SqlUnaryExpression IsNull(SqlExpression operand)
public virtual SqlExpression IsNull(SqlExpression operand)
=> MakeUnary(ExpressionType.Equal, operand, typeof(bool))!;

/// <inheritdoc />
public virtual SqlUnaryExpression IsNotNull(SqlExpression operand)
public virtual SqlExpression IsNotNull(SqlExpression operand)
=> MakeUnary(ExpressionType.NotEqual, operand, typeof(bool))!;

/// <inheritdoc />
public virtual SqlUnaryExpression Convert(SqlExpression operand, Type type, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression Convert(SqlExpression operand, Type type, RelationalTypeMapping? typeMapping = null)
=> MakeUnary(ExpressionType.Convert, operand, type.UnwrapNullableType(), typeMapping)!;

/// <inheritdoc />
public virtual SqlUnaryExpression Not(SqlExpression operand)
public virtual SqlExpression Not(SqlExpression operand)
=> MakeUnary(ExpressionType.Not, operand, operand.Type, operand.TypeMapping)!;

/// <inheritdoc />
public virtual SqlUnaryExpression Negate(SqlExpression operand)
public virtual SqlExpression Negate(SqlExpression operand)
=> MakeUnary(ExpressionType.Negate, operand, operand.Type, operand.TypeMapping)!;

/// <inheritdoc />
public virtual CaseExpression Case(SqlExpression? operand, IReadOnlyList<CaseWhenClause> whenClauses, SqlExpression? elseResult)
public virtual SqlExpression Case(SqlExpression? operand, IReadOnlyList<CaseWhenClause> whenClauses, SqlExpression? elseResult)
{
var operandTypeMapping = operand!.TypeMapping
?? whenClauses.Select(wc => wc.Test.TypeMapping).FirstOrDefault(t => t != null)
Expand Down Expand Up @@ -563,7 +569,7 @@ public virtual CaseExpression Case(SqlExpression? operand, IReadOnlyList<CaseWhe
}

/// <inheritdoc />
public virtual CaseExpression Case(IReadOnlyList<CaseWhenClause> whenClauses, SqlExpression? elseResult)
public virtual SqlExpression Case(IReadOnlyList<CaseWhenClause> whenClauses, SqlExpression? elseResult)
{
var resultTypeMapping = elseResult?.TypeMapping
?? whenClauses.Select(wc => wc.Result.TypeMapping).FirstOrDefault(t => t != null);
Expand All @@ -583,7 +589,7 @@ public virtual CaseExpression Case(IReadOnlyList<CaseWhenClause> whenClauses, Sq
}

/// <inheritdoc />
public virtual SqlFunctionExpression Function(
public virtual SqlExpression Function(
string name,
IEnumerable<SqlExpression> arguments,
bool nullable,
Expand All @@ -602,7 +608,7 @@ public virtual SqlFunctionExpression Function(
}

/// <inheritdoc />
public virtual SqlFunctionExpression Function(
public virtual SqlExpression Function(
string? schema,
string name,
IEnumerable<SqlExpression> arguments,
Expand All @@ -622,7 +628,7 @@ public virtual SqlFunctionExpression Function(
}

/// <inheritdoc />
public virtual SqlFunctionExpression Function(
public virtual SqlExpression Function(
SqlExpression instance,
string name,
IEnumerable<SqlExpression> arguments,
Expand All @@ -645,64 +651,64 @@ public virtual SqlFunctionExpression Function(
}

/// <inheritdoc />
public virtual SqlFunctionExpression NiladicFunction(
public virtual SqlExpression NiladicFunction(
string name,
bool nullable,
Type returnType,
RelationalTypeMapping? typeMapping = null)
=> new(name, nullable, returnType, typeMapping);
=> new SqlFunctionExpression(name, nullable, returnType, typeMapping);

/// <inheritdoc />
public virtual SqlFunctionExpression NiladicFunction(
public virtual SqlExpression NiladicFunction(
string schema,
string name,
bool nullable,
Type returnType,
RelationalTypeMapping? typeMapping = null)
=> new(schema, name, nullable, returnType, typeMapping);
=> new SqlFunctionExpression(schema, name, nullable, returnType, typeMapping);

/// <inheritdoc />
public virtual SqlFunctionExpression NiladicFunction(
public virtual SqlExpression NiladicFunction(
SqlExpression instance,
string name,
bool nullable,
bool instancePropagatesNullability,
Type returnType,
RelationalTypeMapping? typeMapping = null)
=> new(
=> new SqlFunctionExpression(
ApplyDefaultTypeMapping(instance), name, nullable, instancePropagatesNullability, returnType, typeMapping);

/// <inheritdoc />
public virtual ExistsExpression Exists(SelectExpression subquery)
=> new(subquery, _boolTypeMapping);
public virtual SqlExpression Exists(SelectExpression subquery)
=> new ExistsExpression(subquery, _boolTypeMapping);

/// <inheritdoc />
public virtual InExpression In(SqlExpression item, SelectExpression subquery)
public virtual SqlExpression In(SqlExpression item, SelectExpression subquery)
=> ApplyTypeMappingOnIn(new InExpression(item, subquery, _boolTypeMapping));

/// <inheritdoc />
public virtual InExpression In(SqlExpression item, IReadOnlyList<SqlExpression> values)
public virtual SqlExpression In(SqlExpression item, IReadOnlyList<SqlExpression> values)
=> ApplyTypeMappingOnIn(new InExpression(item, values, _boolTypeMapping));

/// <inheritdoc />
public virtual InExpression In(SqlExpression item, SqlParameterExpression valuesParameter)
public virtual SqlExpression In(SqlExpression item, SqlParameterExpression valuesParameter)
=> ApplyTypeMappingOnIn(new InExpression(item, valuesParameter, _boolTypeMapping));

/// <inheritdoc />
public virtual LikeExpression Like(SqlExpression match, SqlExpression pattern, SqlExpression? escapeChar = null)
=> (LikeExpression)ApplyDefaultTypeMapping(new LikeExpression(match, pattern, escapeChar, null));
public virtual SqlExpression Like(SqlExpression match, SqlExpression pattern, SqlExpression? escapeChar = null)
=> ApplyDefaultTypeMapping(new LikeExpression(match, pattern, escapeChar, null));

/// <inheritdoc />
public virtual SqlFragmentExpression Fragment(string sql)
=> new(sql);
public virtual SqlExpression Fragment(string sql)
=> new SqlFragmentExpression(sql);

/// <inheritdoc />
public virtual SqlConstantExpression Constant(object value, RelationalTypeMapping? typeMapping = null)
=> new(value, typeMapping);
public virtual SqlExpression Constant(object value, RelationalTypeMapping? typeMapping = null)
=> new SqlConstantExpression(value, typeMapping);

/// <inheritdoc />
public virtual SqlConstantExpression Constant(object? value, Type type, RelationalTypeMapping? typeMapping = null)
=> new(value, type, typeMapping);
public virtual SqlExpression Constant(object? value, Type type, RelationalTypeMapping? typeMapping = null)
=> new SqlConstantExpression(value, type, typeMapping);

/// <inheritdoc />
public virtual bool TryCreateLeast(
Expand Down
27 changes: 19 additions & 8 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,7 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt
subquery.Offset,
subquery.Limit);

var predicate = VisitSqlBinary(
var predicate = Visit(
_sqlExpressionFactory.Equal(subqueryProjection, item), allowOptimizedExpansion: true, out _);
subquery.ApplyPredicate(predicate);
subquery.ClearOrdering();
Expand Down Expand Up @@ -908,7 +908,7 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt
result,
(expr, nullableValue) => _sqlExpressionFactory.OrElse(
expr,
VisitSqlBinary(_sqlExpressionFactory.Equal(item, nullableValue), allowOptimizedExpansion, out _)));
Visit(_sqlExpressionFactory.Equal(item, nullableValue), allowOptimizedExpansion, out _)));

InExpression ProcessInExpressionValues(
InExpression inExpression,
Expand Down Expand Up @@ -1873,8 +1873,13 @@ private SqlExpression RewriteNullSemantics(
return sqlBinaryExpression.Update(left, right);
}

private SqlExpression SimplifyLogicalSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpression)
private SqlExpression SimplifyLogicalSqlBinaryExpression(SqlExpression expression)
{
if (expression is not SqlBinaryExpression sqlBinaryExpression)
{
return expression;
}

if (sqlBinaryExpression is
{
Left: SqlUnaryExpression { OperatorType: ExpressionType.Equal or ExpressionType.NotEqual } leftUnary,
Expand Down Expand Up @@ -1930,13 +1935,14 @@ private SqlExpression SimplifyLogicalSqlBinaryExpression(SqlBinaryExpression sql
/// <summary>
/// Attempts to simplify a unary not operation on a non-nullable operand.
/// </summary>
/// <param name="sqlUnaryExpression">The expression to simplify.</param>
/// <param name="expression">The expression to simplify.</param>
/// <returns>The simplified expression, or the original expression if it cannot be simplified.</returns>
protected virtual SqlExpression OptimizeNonNullableNotExpression(SqlUnaryExpression sqlUnaryExpression)
protected virtual SqlExpression OptimizeNonNullableNotExpression(SqlExpression expression)
{
if (sqlUnaryExpression.OperatorType != ExpressionType.Not)
if (expression is not SqlUnaryExpression sqlUnaryExpression
|| sqlUnaryExpression.OperatorType != ExpressionType.Not)
{
return sqlUnaryExpression;
return expression;
}

switch (sqlUnaryExpression.Operand)
Expand Down Expand Up @@ -2207,8 +2213,13 @@ protected virtual TableExpressionBase UpdateParameterCollection(
SqlParameterExpression newCollectionParameter)
=> throw new InvalidOperationException();

private SqlExpression ProcessNullNotNull(SqlUnaryExpression sqlUnaryExpression, bool operandNullable)
private SqlExpression ProcessNullNotNull(SqlExpression sqlExpression, bool operandNullable)
{
if (sqlExpression is not SqlUnaryExpression sqlUnaryExpression)
{
return sqlExpression;
}

if (!operandNullable)
{
// when we know that operand is non-nullable:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,14 +241,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

// CONCAT_WS filters out nulls, but string.Join treats them as empty strings; so coalesce (which is a no-op for non-nullable
// arguments).
arguments[i + 1] = sqlArgument switch
{
ColumnExpression { IsNullable: false } => sqlArgument,
SqlConstantExpression constantExpression => constantExpression.Value is null
? _sqlExpressionFactory.Constant(string.Empty)
: constantExpression,
_ => Dependencies.SqlExpressionFactory.Coalesce(sqlArgument, _sqlExpressionFactory.Constant(string.Empty))
};
arguments[i + 1] = Dependencies.SqlExpressionFactory.Coalesce(sqlArgument, _sqlExpressionFactory.Constant(string.Empty));
}

// CONCAT_WS never returns null; a null delimiter is interpreted as an empty string, and null arguments are skipped
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ public class SqlServerDateTimeMemberTranslator(
_ => null
};

SqlFunctionExpression DatePart(string part)
SqlExpression DatePart(string part)
=> sqlExpressionFactory.Function(
"DATEPART",
arguments: [sqlExpressionFactory.Fragment(part), instance!],
Expand Down
Loading