Skip to content

Commit

Permalink
Null semantics for row value equality
Browse files Browse the repository at this point in the history
Closes #2357
  • Loading branch information
roji committed Oct 16, 2022
1 parent 95aa5c4 commit d4a19a2
Show file tree
Hide file tree
Showing 2 changed files with 287 additions and 5 deletions.
152 changes: 150 additions & 2 deletions src/EFCore.PG/Query/Internal/NpgsqlSqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Runtime.CompilerServices;
using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions;
using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal;
using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping;
Expand All @@ -21,6 +22,153 @@ public NpgsqlSqlNullabilityProcessor(
: base(dependencies, useRelationalNulls)
=> _sqlExpressionFactory = dependencies.SqlExpressionFactory;

/// <inheritdoc />
protected override SqlExpression VisitSqlBinary(
SqlBinaryExpression sqlBinaryExpression,
bool allowOptimizedExpansion,
out bool nullable)
{
if (sqlBinaryExpression is not
{
OperatorType: ExpressionType.Equal or ExpressionType.NotEqual,
Left: PostgresRowValueExpression leftRowValue,
Right: PostgresRowValueExpression rightRowValue
})
{
return base.VisitSqlBinary(sqlBinaryExpression, allowOptimizedExpansion, out nullable);
}

// Equality checks between row values require some special null semantics compensation.
// Row value equality/inequality works the same as regular equals/non-equals; this means that it's fine as long as we're comparing
// non-nullable values (no need to compensate), but for nullable values, we need to compensate. We go over the value pairs, and
// extract out pairs that require compensation to an expanded, non-value-tuple expression (regular equality null semantics).
// Note that PostgreSQL does have DISTINCT FROM/NOT DISTINCT FROM which would have been perfect here, but those still don't use
// indexes.
// Note that we don't do compensation for comparisons (e.g. greater than) since these are expressed via EF.Functions, which
// correspond directly to SQL constructs.
// The PG docs around this are in https://www.postgresql.org/docs/current/functions-comparisons.html#ROW-WISE-COMPARISON

Check.DebugAssert(leftRowValue.Values.Count == rightRowValue.Values.Count, "left.Values.Count == right.Values.Count");
var count = leftRowValue.Values.Count;

var operatorType = sqlBinaryExpression.OperatorType;

SqlExpression? expandedExpression = null;
List<SqlExpression>? visitedLeftValues = null;
List<SqlExpression>? visitedRightValues = null;

for (var i = 0; i < count; i++)
{
// Visit the left value, populating visitedLeftValues only if we haven't yet switched to an expanded expression, and only if
// the visitation actually changed something, and
var leftValue = leftRowValue.Values[i];
var rightValue = rightRowValue.Values[i];
var visitedRightValue = Visit(rightRowValue.Values[i], out var rightNullable);
var visitedLeftValue = Visit(leftRowValue.Values[i], out var leftNullable);

if (!leftNullable && !rightNullable
|| allowOptimizedExpansion && (!leftNullable || !rightNullable))
{
// The comparison for this value pair doesn't require expansion and can remain in the row value (so continue below).
// But if the visitation above changed a value, construct a list to hold the visited values.
if (visitedLeftValue != leftValue && visitedLeftValues is null)
{
visitedLeftValues = SliceToList(leftRowValue.Values, count, i);
}

if (visitedLeftValues is not null)
{
visitedLeftValues.Add(visitedLeftValue);
}

if (visitedRightValue != rightValue && visitedRightValues is null)
{
visitedRightValues = SliceToList(rightRowValue.Values, count, i);
}

if (visitedRightValues is not null)
{
visitedRightValues.Add(visitedRightValue);
}

continue;
}

// If we're here, the value pair requires null semantics compensation. We build a binary expression around the pair and visit
// that (that adds the compensation). We then chain all such expressions together with AND.
var valueBinaryExpression = Visit(
_sqlExpressionFactory.MakeBinary(operatorType, visitedLeftValue, visitedRightValue, null)!, allowOptimizedExpansion, out _);

if (expandedExpression is null)
{
// visitedLeft/RightValues will contain all pairs that can remain in the row value (since they don't require compensation)
visitedLeftValues = SliceToList(leftRowValue.Values, count, i);
visitedRightValues = SliceToList(rightRowValue.Values, count, i);

expandedExpression = valueBinaryExpression;
}
else
{
expandedExpression = _sqlExpressionFactory.AndAlso(expandedExpression, valueBinaryExpression);
}
}

nullable = false;

if (expandedExpression is null)
{
// No pairs required compensation, so they all stay in the row value.
// Either return the original binary expression (if no value visitation changed anything), or construct a new one over the
// visited values.
return visitedLeftValues is null && visitedRightValues is null
? sqlBinaryExpression
: _sqlExpressionFactory.MakeBinary(
operatorType,
visitedLeftValues is null
? leftRowValue
: new PostgresRowValueExpression(visitedLeftValues, leftRowValue.Type, leftRowValue.TypeMapping),
visitedRightValues is null
? rightRowValue
: new PostgresRowValueExpression(visitedRightValues, leftRowValue.Type, leftRowValue.TypeMapping),
null)!;
}

Check.DebugAssert(visitedLeftValues is not null, "visitedLeftValues is not null");
Check.DebugAssert(visitedRightValues is not null, "visitedRightValues is not null");

// Some pairs required compensation. Combine the pairs which didn't (in visitedLeft/RightValues) with expandedExpression
// (which contains the logic for those that did).
return visitedLeftValues.Count switch
{
0 => expandedExpression,
1 => _sqlExpressionFactory.AndAlso(
_sqlExpressionFactory.MakeBinary(operatorType, visitedLeftValues[0], visitedRightValues[0], null)!,
expandedExpression),
// Technically the CLR type and type mappings are incorrect, as we're truncating the row values.
// But that shouldn't matter.
_ => _sqlExpressionFactory.AndAlso(
_sqlExpressionFactory.MakeBinary(
sqlBinaryExpression.OperatorType,
new PostgresRowValueExpression(visitedLeftValues, leftRowValue.Type, leftRowValue.TypeMapping),
new PostgresRowValueExpression(visitedRightValues, rightRowValue.Type, rightRowValue.TypeMapping),
null)!,
expandedExpression)
};

[MethodImpl(MethodImplOptions.AggressiveInlining)]
static List<SqlExpression> SliceToList(IReadOnlyList<SqlExpression> source, int capacity, int count)
{
var list = new List<SqlExpression>(capacity);

for (var i = 0; i < count; i++)
{
list.Add(source[i]);
}

return list;
}
}

/// <inheritdoc />
protected override SqlExpression VisitCustomSqlExpression(
SqlExpression sqlExpression,
Expand All @@ -35,7 +183,7 @@ PostgresAllExpression postgresAllExpression
PostgresArrayIndexExpression arrayIndexExpression
=> VisitArrayIndex(arrayIndexExpression, allowOptimizedExpansion, out nullable),
PostgresBinaryExpression binaryExpression
=> VisitBinary(binaryExpression, allowOptimizedExpansion, out nullable),
=> VisitPostgresBinary(binaryExpression, allowOptimizedExpansion, out nullable),
PostgresILikeExpression ilikeExpression
=> VisitILike(ilikeExpression, allowOptimizedExpansion, out nullable),
PostgresJsonTraversalExpression postgresJsonTraversalExpression
Expand Down Expand Up @@ -179,7 +327,7 @@ protected virtual SqlExpression VisitArrayIndex(
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected virtual SqlExpression VisitBinary(PostgresBinaryExpression binaryExpression, bool allowOptimizedExpansion, out bool nullable)
protected virtual SqlExpression VisitPostgresBinary(PostgresBinaryExpression binaryExpression, bool allowOptimizedExpansion, out bool nullable)
{
Check.NotNull(binaryExpression, nameof(binaryExpression));

Expand Down
140 changes: 137 additions & 3 deletions test/EFCore.PG.FunctionalTests/Query/NullSemanticsQueryNpgsqlTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,143 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query;
// ReSharper disable once UnusedMember.Global
public class NullSemanticsQueryNpgsqlTest : NullSemanticsQueryTestBase<NullSemanticsQueryNpgsqlFixture>
{
public NullSemanticsQueryNpgsqlTest(NullSemanticsQueryNpgsqlFixture fixture)
public NullSemanticsQueryNpgsqlTest(NullSemanticsQueryNpgsqlFixture fixture, ITestOutputHelper testOutputHelper)
: base(fixture)
=> Fixture.TestSqlLoggerFactory.Clear();
{
Fixture.TestSqlLoggerFactory.Clear();
Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Compare_row_values_equal_without_expansion(bool async)
{
await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
.Where(e => ValueTuple.Create(e.IntA, e.StringA).Equals(ValueTuple.Create(e.IntB, e.StringB)))
.Select(e => e.Id));

await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
.Where(e => ValueTuple.Create(e.IntA, e.StringA).Equals(ValueTuple.Create(e.IntB, e.NullableStringB)))
.Select(e => e.Id));

await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
.Where(e => ValueTuple.Create(e.IntA, e.NullableStringA).Equals(ValueTuple.Create(e.IntB, e.StringB)))
.Select(e => e.Id));

AssertSql(
"""
SELECT e."Id"
FROM "Entities1" AS e
WHERE (e."IntA", e."StringA") = (e."IntB", e."StringB")
""",
//
"""
SELECT e."Id"
FROM "Entities1" AS e
WHERE (e."IntA", e."StringA") = (e."IntB", e."NullableStringB")
""",
//
"""
SELECT e."Id"
FROM "Entities1" AS e
WHERE (e."IntA", e."NullableStringA") = (e."IntB", e."StringB")
""");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Compare_row_values_equal_with_expansion(bool async)
{
await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
.Where(e => ValueTuple.Create(e.NullableStringA, e.IntA, e.BoolA).Equals(ValueTuple.Create(e.NullableStringB, e.IntB, e.BoolB)))
.Select(e => e.Id));

await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
.Where(e => ValueTuple.Create(e.IntA, e.NullableStringA, e.BoolA).Equals(ValueTuple.Create(e.IntB, e.NullableStringB, e.BoolB)))
.Select(e => e.Id));

await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
.Where(e => ValueTuple.Create(e.IntA, e.StringA, e.NullableBoolA).Equals(ValueTuple.Create(e.IntB, e.StringB, e.NullableBoolB)))
.Select(e => e.Id));

AssertSql(
"""
SELECT e."Id"
FROM "Entities1" AS e
WHERE (e."IntA", e."BoolA") = (e."IntB", e."BoolB") AND (e."NullableStringA" = e."NullableStringB" OR ((e."NullableStringA" IS NULL) AND (e."NullableStringB" IS NULL)))
""",
//
"""
SELECT e."Id"
FROM "Entities1" AS e
WHERE (e."IntA", e."BoolA") = (e."IntB", e."BoolB") AND (e."NullableStringA" = e."NullableStringB" OR ((e."NullableStringA" IS NULL) AND (e."NullableStringB" IS NULL)))
""",
//
"""
SELECT e."Id"
FROM "Entities1" AS e
WHERE (e."IntA", e."StringA") = (e."IntB", e."StringB") AND (e."NullableBoolA" = e."NullableBoolB" OR ((e."NullableBoolA" IS NULL) AND (e."NullableBoolB" IS NULL)))
""");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Compare_row_values_not_equal(bool async)
{
await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
.Where(e => !ValueTuple.Create(e.IntA, e.StringA).Equals(ValueTuple.Create(e.IntB, e.StringB)))
.Select(e => e.Id));

await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
.Where(e => !ValueTuple.Create(e.IntA, e.StringA).Equals(ValueTuple.Create(e.IntB, e.NullableStringB)))
.Select(e => e.Id));

await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
.Where(e => !ValueTuple.Create(e.IntA, e.NullableStringA).Equals(ValueTuple.Create(e.IntB, e.StringB)))
.Select(e => e.Id));

await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
.Where(e => !ValueTuple.Create(e.IntA, e.NullableStringA).Equals(ValueTuple.Create(e.IntB, e.NullableStringB)))
.Select(e => e.Id));

await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
.Where(e => !ValueTuple.Create(e.IntA, e.StringA, e.NullableBoolA).Equals(ValueTuple.Create(e.IntB, e.StringB, e.NullableBoolB)))
.Select(e => e.Id));

AssertSql(
"""
SELECT e."Id"
FROM "Entities1" AS e
WHERE (e."IntA", e."StringA") <> (e."IntB", e."StringB")
""",
//
"""
SELECT e."Id"
FROM "Entities1" AS e
WHERE e."IntA" <> e."IntB" OR e."StringA" <> e."NullableStringB" OR (e."NullableStringB" IS NULL)
""",
//
"""
SELECT e."Id"
FROM "Entities1" AS e
WHERE e."IntA" <> e."IntB" OR e."NullableStringA" <> e."StringB" OR (e."NullableStringA" IS NULL)
""",
//
"""
SELECT e."Id"
FROM "Entities1" AS e
WHERE e."IntA" <> e."IntB" OR ((e."NullableStringA" <> e."NullableStringB" OR (e."NullableStringA" IS NULL) OR (e."NullableStringB" IS NULL)) AND ((e."NullableStringA" IS NOT NULL) OR (e."NullableStringB" IS NOT NULL)))
""",
//
"""
SELECT e."Id"
FROM "Entities1" AS e
WHERE (e."IntA", e."StringA") <> (e."IntB", e."StringB") OR ((e."NullableBoolA" <> e."NullableBoolB" OR (e."NullableBoolA" IS NULL) OR (e."NullableBoolB" IS NULL)) AND ((e."NullableBoolA" IS NOT NULL) OR (e."NullableBoolB" IS NOT NULL)))
""");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

protected override void ClearLog()
=> Fixture.TestSqlLoggerFactory.Clear();
Expand All @@ -27,4 +161,4 @@ protected override NullSemanticsContext CreateContext(bool useRelationalNulls =

return context;
}
}
}

0 comments on commit d4a19a2

Please sign in to comment.