Skip to content

Commit

Permalink
Simpler/safer patch fix for #35393 that doesn't compensate for some c…
Browse files Browse the repository at this point in the history
…hanges
  • Loading branch information
maumar committed Jan 10, 2025
1 parent cc16006 commit 4343d68
Show file tree
Hide file tree
Showing 11 changed files with 170 additions and 35 deletions.
30 changes: 29 additions & 1 deletion src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ namespace Microsoft.EntityFrameworkCore.Query;
/// <inheritdoc />
public class SqlExpressionFactory : ISqlExpressionFactory
{
private static readonly bool UseOldBehavior35393 =
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35393", out var enabled35393) && enabled35393;

private readonly IRelationalTypeMappingSource _typeMappingSource;
private readonly RelationalTypeMapping _boolTypeMapping;

Expand Down Expand Up @@ -660,20 +663,45 @@ private SqlExpression Not(SqlExpression operand, SqlExpression? existingExpressi
SqlBinaryExpression { OperatorType: ExpressionType.OrElse } binary
=> AndAlso(Not(binary.Left), Not(binary.Right)),

// use equality where possible
// use equality where possible - we can only do this when we know a is not null
// at this point we are limited to constants, parameters and columns
// more comprehensive optimization is done during null expansion
// !(a == true) -> a == false
// !(a == false) -> a == true
SqlBinaryExpression { OperatorType: ExpressionType.Equal, Right: SqlConstantExpression { Value: bool } } binary
when UseOldBehavior35393
=> Equal(binary.Left, Not(binary.Right)),

SqlBinaryExpression
{
OperatorType: ExpressionType.Equal,
Right: SqlConstantExpression { Value: bool },
Left: SqlConstantExpression { Value: bool }
or SqlParameterExpression { IsNullable: false }
or ColumnExpression { IsNullable: false }
} binary
=> Equal(binary.Left, Not(binary.Right)),

// !(true == a) -> false == a
// !(false == a) -> true == a
SqlBinaryExpression { OperatorType: ExpressionType.Equal, Left: SqlConstantExpression { Value: bool } } binary
when UseOldBehavior35393
=> Equal(Not(binary.Left), binary.Right),

SqlBinaryExpression
{
OperatorType: ExpressionType.Equal,
Left: SqlConstantExpression { Value: bool },
Right: SqlConstantExpression { Value: bool }
or SqlParameterExpression { IsNullable: false }
or ColumnExpression { IsNullable: false }
} binary
=> Equal(Not(binary.Left), binary.Right),

// !(a == b) -> a != b
SqlBinaryExpression { OperatorType: ExpressionType.Equal } sqlBinaryOperand => NotEqual(
sqlBinaryOperand.Left, sqlBinaryOperand.Right),

// !(a != b) -> a == b
SqlBinaryExpression { OperatorType: ExpressionType.NotEqual } sqlBinaryOperand => Equal(
sqlBinaryOperand.Left, sqlBinaryOperand.Right),
Expand Down
58 changes: 49 additions & 9 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ namespace Microsoft.EntityFrameworkCore.Query;
/// </summary>
public class SqlNullabilityProcessor
{
private static readonly bool UseOldBehavior35393 =
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35393", out var enabled35393) && enabled35393;

private readonly List<ColumnExpression> _nonNullableColumns;
private readonly List<ColumnExpression> _nullValueColumns;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
Expand Down Expand Up @@ -1343,7 +1346,7 @@ protected virtual SqlExpression VisitSqlBinary(
// we assume that NullSemantics rewrite is only needed (on the current level)
// if the optimization didn't make any changes.
// Reason is that optimization can/will change the nullability of the resulting expression
// and that inforation is not tracked/stored anywhere
// and that information is not tracked/stored anywhere
// so we can no longer rely on nullabilities that we computed earlier (leftNullable, rightNullable)
// when performing null semantics rewrite.
// It should be fine because current optimizations *radically* change the expression
Expand Down Expand Up @@ -1844,9 +1847,17 @@ private SqlExpression RewriteNullSemantics(

var leftIsNull = ProcessNullNotNull(_sqlExpressionFactory.IsNull(left), leftNullable);
var leftIsNotNull = _sqlExpressionFactory.Not(leftIsNull);
if (!UseOldBehavior35393)
{
leftIsNotNull = OptimizeNotExpression(leftIsNotNull);
}

var rightIsNull = ProcessNullNotNull(_sqlExpressionFactory.IsNull(right), rightNullable);
var rightIsNotNull = _sqlExpressionFactory.Not(rightIsNull);
if (!UseOldBehavior35393)
{
rightIsNotNull = OptimizeNotExpression(rightIsNotNull);
}

SqlExpression body;
if (leftNegated == rightNegated)
Expand Down Expand Up @@ -1879,6 +1890,10 @@ private SqlExpression RewriteNullSemantics(
{
// the factory takes care of simplifying using DeMorgan
body = _sqlExpressionFactory.Not(body);
if (!UseOldBehavior35393)
{
body = OptimizeNotExpression(body);
}
}

return body;
Expand All @@ -1900,14 +1915,39 @@ protected virtual SqlExpression OptimizeNotExpression(SqlExpression expression)
// !(a >= b) -> a < b
// !(a < b) -> a >= b
// !(a <= b) -> a > b
if (sqlUnaryExpression.Operand is SqlBinaryExpression sqlBinaryOperand
&& TryNegate(sqlBinaryOperand.OperatorType, out var negated))
{
return _sqlExpressionFactory.MakeBinary(
negated,
sqlBinaryOperand.Left,
sqlBinaryOperand.Right,
sqlBinaryOperand.TypeMapping)!;
if (sqlUnaryExpression.Operand is SqlBinaryExpression sqlBinaryOperand)
{
if (TryNegate(sqlBinaryOperand.OperatorType, out var negated))
{
return _sqlExpressionFactory.MakeBinary(
negated,
sqlBinaryOperand.Left,
sqlBinaryOperand.Right,
sqlBinaryOperand.TypeMapping)!;
}

if (!UseOldBehavior35393)
{
// use equality where possible - at this point (true == null) and (false == null) have been converted to
// IS NULL / IS NOT NULL (i.e. false), so this optimization is safe to do. See #35393
// !(a == true) -> a == false
// !(a == false) -> a == true
if (sqlBinaryOperand is { OperatorType: ExpressionType.Equal, Right: SqlConstantExpression { Value: bool } })
{
return _sqlExpressionFactory.Equal(
sqlBinaryOperand.Left,
OptimizeNotExpression(_sqlExpressionFactory.Not(sqlBinaryOperand.Right)));
}

// !(true == a) -> false == a
// !(false == a) -> true == a
if (sqlBinaryOperand is { OperatorType: ExpressionType.Equal, Left: SqlConstantExpression { Value: bool } })
{
return _sqlExpressionFactory.Equal(
OptimizeNotExpression(_sqlExpressionFactory.Not(sqlBinaryOperand.Left)),
sqlBinaryOperand.Right);
}
}
}

// the factory can optimize most `NOT` expressions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,33 @@ await AssertQueryScalar(
ss => ss.Set<NullSemanticsEntity1>().Where(e => true).Select(e => e.Id));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Compare_constant_true_to_nullable_column_negated(bool async)
=> await AssertQueryScalar(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(x => !(true == x.NullableBoolA)).Select(x => x.Id));


[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Compare_constant_true_to_non_nullable_column_negated(bool async)
=> await AssertQueryScalar(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(x => !(true == x.BoolA)).Select(x => x.Id));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Compare_constant_true_to_expression_which_evaluates_to_null(bool async)
{
var prm = default(bool?);

await AssertQueryScalar(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(x => x.NullableBoolA != null
&& !object.Equals(true, x.NullableBoolA == null ? null : prm)).Select(x => x.Id));
}

// We can't client-evaluate Like (for the expected results).
// However, since the test data has no LIKE wildcards, it effectively functions like equality - except that 'null like null' returns
// false instead of true. So we have this "lite" implementation which doesn't support wildcards.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,24 @@ join o in ss.Set<Order>().OrderBy(o => o.OrderID).Take(100) on c.CustomerID equa
from o in lo.Where(x => x.CustomerID.StartsWith("A"))
select new { c.CustomerID, o.OrderID });

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_on_true_equal_true(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().GroupJoin(
ss.Set<Order>(),
x => true,
x => true,
(c, g) => new { c, g })
.Select(x => new { x.c.CustomerID, Orders = x.g }),
elementSorter: e => e.CustomerID,
elementAsserter: (e, a) =>
{
Assert.Equal(e.CustomerID, a.CustomerID);
AssertCollection(e.Orders, a.Orders, elementSorter: ee => ee.OrderID);
});

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Inner_join_with_tautology_predicate_converts_to_cross_join(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5021,7 +5021,7 @@ INNER JOIN (
FROM [Factions] AS [f]
WHERE [f].[Name] = N'Swarm'
) AS [f0] ON [l].[Name] = [f0].[CommanderName]
WHERE [f0].[Eradicated] = CAST(0 AS bit) OR [f0].[Eradicated] IS NULL
WHERE [f0].[Eradicated] <> CAST(1 AS bit) OR [f0].[Eradicated] IS NULL
""");
}

Expand All @@ -5038,7 +5038,7 @@ LEFT JOIN (
FROM [Factions] AS [f]
WHERE [f].[Name] = N'Swarm'
) AS [f0] ON [l].[Name] = [f0].[CommanderName]
WHERE [f0].[Eradicated] = CAST(0 AS bit) OR [f0].[Eradicated] IS NULL
WHERE [f0].[Eradicated] <> CAST(1 AS bit) OR [f0].[Eradicated] IS NULL
""");
}

Expand Down Expand Up @@ -7585,17 +7585,17 @@ public override async Task Join_inner_source_custom_projection_followed_by_filte

AssertSql(
"""
SELECT CASE
WHEN [f].[Name] = N'Locust' THEN CAST(1 AS bit)
END AS [IsEradicated], [f].[CommanderName], [f].[Name]
FROM [LocustLeaders] AS [l]
INNER JOIN [Factions] AS [f] ON [l].[Name] = [f].[CommanderName]
WHERE CASE
WHEN [f].[Name] = N'Locust' THEN CAST(1 AS bit)
END = CAST(0 AS bit) OR CASE
WHEN [f].[Name] = N'Locust' THEN CAST(1 AS bit)
END IS NULL
""");
SELECT CASE
WHEN [f].[Name] = N'Locust' THEN CAST(1 AS bit)
END AS [IsEradicated], [f].[CommanderName], [f].[Name]
FROM [LocustLeaders] AS [l]
INNER JOIN [Factions] AS [f] ON [l].[Name] = [f].[CommanderName]
WHERE CASE
WHEN [f].[Name] = N'Locust' THEN CAST(1 AS bit)
END <> CAST(1 AS bit) OR CASE
WHEN [f].[Name] = N'Locust' THEN CAST(1 AS bit)
END IS NULL
""");
}

public override async Task Byte_array_contains_literal(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,22 @@ ORDER BY [o].[OrderID]
""");
}

public override async Task GroupJoin_on_true_equal_true(bool async)
{
await base.GroupJoin_on_true_equal_true(async);

AssertSql(
"""
SELECT [c].[CustomerID], [o0].[OrderID], [o0].[CustomerID], [o0].[EmployeeID], [o0].[OrderDate]
FROM [Customers] AS [c]
OUTER APPLY (
SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
) AS [o0]
ORDER BY [c].[CustomerID]
""");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6824,7 +6824,7 @@ INNER JOIN (
FROM [LocustHordes] AS [l1]
WHERE [l1].[Name] = N'Swarm'
) AS [l2] ON [u].[Name] = [l2].[CommanderName]
WHERE [l2].[Eradicated] = CAST(0 AS bit) OR [l2].[Eradicated] IS NULL
WHERE [l2].[Eradicated] <> CAST(1 AS bit) OR [l2].[Eradicated] IS NULL
""");
}

Expand All @@ -6847,7 +6847,7 @@ LEFT JOIN (
FROM [LocustHordes] AS [l1]
WHERE [l1].[Name] = N'Swarm'
) AS [l2] ON [u].[Name] = [l2].[CommanderName]
WHERE [l2].[Eradicated] = CAST(0 AS bit) OR [l2].[Eradicated] IS NULL
WHERE [l2].[Eradicated] <> CAST(1 AS bit) OR [l2].[Eradicated] IS NULL
""");
}

Expand Down Expand Up @@ -10136,7 +10136,7 @@ FROM [LocustCommanders] AS [l0]
INNER JOIN [LocustHordes] AS [l1] ON [u].[Name] = [l1].[CommanderName]
WHERE CASE
WHEN [l1].[Name] = N'Locust' THEN CAST(1 AS bit)
END = CAST(0 AS bit) OR CASE
END <> CAST(1 AS bit) OR CASE
WHEN [l1].[Name] = N'Locust' THEN CAST(1 AS bit)
END IS NULL
""");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5766,7 +5766,7 @@ FROM [Factions] AS [f]
LEFT JOIN [LocustHordes] AS [l0] ON [f].[Id] = [l0].[Id]
WHERE [l0].[Id] IS NOT NULL AND [f].[Name] = N'Swarm'
) AS [s] ON [l].[Name] = [s].[CommanderName]
WHERE [s].[Eradicated] = CAST(0 AS bit) OR [s].[Eradicated] IS NULL
WHERE [s].[Eradicated] <> CAST(1 AS bit) OR [s].[Eradicated] IS NULL
""");
}

Expand All @@ -5786,7 +5786,7 @@ FROM [Factions] AS [f]
LEFT JOIN [LocustHordes] AS [l0] ON [f].[Id] = [l0].[Id]
WHERE [l0].[Id] IS NOT NULL AND [f].[Name] = N'Swarm'
) AS [s] ON [l].[Name] = [s].[CommanderName]
WHERE [s].[Eradicated] = CAST(0 AS bit) OR [s].[Eradicated] IS NULL
WHERE [s].[Eradicated] <> CAST(1 AS bit) OR [s].[Eradicated] IS NULL
""");
}

Expand Down Expand Up @@ -8602,7 +8602,7 @@ WHERE [l0].[Id] IS NOT NULL
) AS [s] ON [l].[Name] = [s].[CommanderName]
WHERE CASE
WHEN [s].[Name] = N'Locust' THEN CAST(1 AS bit)
END = CAST(0 AS bit) OR CASE
END <> CAST(1 AS bit) OR CASE
WHEN [s].[Name] = N'Locust' THEN CAST(1 AS bit)
END IS NULL
""");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2564,7 +2564,7 @@ SELECT CASE
INNER JOIN [Factions] FOR SYSTEM_TIME AS OF '2010-01-01T00:00:00.0000000' AS [f] ON [l].[Name] = [f].[CommanderName]
WHERE CASE
WHEN [f].[Name] = N'Locust' THEN CAST(1 AS bit)
END = CAST(0 AS bit) OR CASE
END <> CAST(1 AS bit) OR CASE
WHEN [f].[Name] = N'Locust' THEN CAST(1 AS bit)
END IS NULL
""");
Expand Down Expand Up @@ -4497,7 +4497,7 @@ INNER JOIN (
FROM [Factions] FOR SYSTEM_TIME AS OF '2010-01-01T00:00:00.0000000' AS [f]
WHERE [f].[Name] = N'Swarm'
) AS [f0] ON [l].[Name] = [f0].[CommanderName]
WHERE [f0].[Eradicated] = CAST(0 AS bit) OR [f0].[Eradicated] IS NULL
WHERE [f0].[Eradicated] <> CAST(1 AS bit) OR [f0].[Eradicated] IS NULL
""");
}

Expand Down Expand Up @@ -6446,7 +6446,7 @@ LEFT JOIN (
FROM [Factions] FOR SYSTEM_TIME AS OF '2010-01-01T00:00:00.0000000' AS [f]
WHERE [f].[Name] = N'Swarm'
) AS [f0] ON [l].[Name] = [f0].[CommanderName]
WHERE [f0].[Eradicated] = CAST(0 AS bit) OR [f0].[Eradicated] IS NULL
WHERE [f0].[Eradicated] <> CAST(1 AS bit) OR [f0].[Eradicated] IS NULL
""");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2874,7 +2874,7 @@ SELECT CASE
INNER JOIN "Factions" AS "f" ON "l"."Name" = "f"."CommanderName"
WHERE CASE
WHEN "f"."Name" = 'Locust' THEN 1
END = 0 OR CASE
END <> 1 OR CASE
WHEN "f"."Name" = 'Locust' THEN 1
END IS NULL
""");
Expand Down Expand Up @@ -5481,7 +5481,7 @@ LEFT JOIN (
FROM "Factions" AS "f"
WHERE "f"."Name" = 'Swarm'
) AS "f0" ON "l"."Name" = "f0"."CommanderName"
WHERE "f0"."Eradicated" = 0 OR "f0"."Eradicated" IS NULL
WHERE "f0"."Eradicated" <> 1 OR "f0"."Eradicated" IS NULL
""");
}

Expand Down Expand Up @@ -5549,7 +5549,7 @@ INNER JOIN (
FROM "Factions" AS "f"
WHERE "f"."Name" = 'Swarm'
) AS "f0" ON "l"."Name" = "f0"."CommanderName"
WHERE "f0"."Eradicated" = 0 OR "f0"."Eradicated" IS NULL
WHERE "f0"."Eradicated" <> 1 OR "f0"."Eradicated" IS NULL
""");
}

Expand Down
Loading

0 comments on commit 4343d68

Please sign in to comment.