From d4a19a2fb147677097cfecff0e7001037bd94f40 Mon Sep 17 00:00:00 2001
From: Shay Rojansky
Date: Sun, 16 Oct 2022 19:19:05 +0200
Subject: [PATCH] Null semantics for row value equality

Closes #2357
---
 .../Internal/NpgsqlSqlNullabilityProcessor.cs | 152 +++++++++++++++++-
 .../Query/NullSemanticsQueryNpgsqlTest.cs     | 140 +++++++++++++++-
 2 files changed, 287 insertions(+), 5 deletions(-)

diff --git a/src/EFCore.PG/Query/Internal/NpgsqlSqlNullabilityProcessor.cs b/src/EFCore.PG/Query/Internal/NpgsqlSqlNullabilityProcessor.cs
index 2e13595498..ffb4a7caa9 100644
--- a/src/EFCore.PG/Query/Internal/NpgsqlSqlNullabilityProcessor.cs
+++ b/src/EFCore.PG/Query/Internal/NpgsqlSqlNullabilityProcessor.cs
@@ -1,3 +1,4 @@
+using System.Runtime.CompilerServices;
 using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions;
 using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal;
 using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping;
@@ -21,6 +22,153 @@ public NpgsqlSqlNullabilityProcessor(
         : base(dependencies, useRelationalNulls)
         => _sqlExpressionFactory = dependencies.SqlExpressionFactory;
 
+    /// <inheritdoc />
+    /// <remarks>
+    ///     Adds null semantics compensation to equality/inequality comparisons between two PostgreSQL row values.
+    /// </remarks>
+    protected override SqlExpression VisitSqlBinary(
+        SqlBinaryExpression sqlBinaryExpression,
+        bool allowOptimizedExpansion,
+        out bool nullable)
+    {
+        if (sqlBinaryExpression is not
+            {
+                OperatorType: ExpressionType.Equal or ExpressionType.NotEqual,
+                Left: PostgresRowValueExpression leftRowValue,
+                Right: PostgresRowValueExpression rightRowValue
+            })
+        {
+            return base.VisitSqlBinary(sqlBinaryExpression, allowOptimizedExpansion, out nullable);
+        }
+
+        // Equality checks between row values require some special null semantics compensation.
+        // Row value equality/inequality works the same as regular equals/non-equals; this means that it's fine as long as we're comparing
+        // non-nullable values (no need to compensate), but for nullable values, we need to compensate. We go over the value pairs, and
+        // extract out pairs that require compensation to an expanded, non-value-tuple expression (regular equality null semantics).
+        // The extracted pairs are AND-ed together (and with the remaining row value) for equality, and OR-ed for inequality.
+        // Note that PostgreSQL does have [NOT] DISTINCT FROM which would have been perfect here, but those still don't use indexes.
+        // Note that we don't do compensation for comparisons (e.g. greater than) since these are expressed via EF.Functions, which
+        // correspond directly to SQL constructs.
+        // The PG docs around this are in https://www.postgresql.org/docs/current/functions-comparisons.html#ROW-WISE-COMPARISON
+
+        Check.DebugAssert(leftRowValue.Values.Count == rightRowValue.Values.Count, "left.Values.Count == right.Values.Count");
+        var count = leftRowValue.Values.Count;
+        var operatorType = sqlBinaryExpression.OperatorType;
+
+        SqlExpression? expandedExpression = null;
+        List<SqlExpression>? visitedLeftValues = null;
+        List<SqlExpression>? visitedRightValues = null;
+
+        for (var i = 0; i < count; i++)
+        {
+            // Visit both values of the pair. The visited values are only collected into visitedLeft/RightValues (below) once a
+            // visitation actually changed something, or once we've switched to an expanded expression.
+            var leftValue = leftRowValue.Values[i];
+            var rightValue = rightRowValue.Values[i];
+            var visitedRightValue = Visit(rightValue, out var rightNullable);
+            var visitedLeftValue = Visit(leftValue, out var leftNullable);
+
+            if (!leftNullable && !rightNullable
+                || allowOptimizedExpansion && operatorType == ExpressionType.Equal && (!leftNullable || !rightNullable))
+            {
+                // The comparison for this value pair doesn't require expansion and can remain in the row value (so continue below).
+                // With optimized expansion, a single nullable side is only safe for equality (a NULL result is treated as false);
+                // an inequality against NULL must evaluate to true, so it always requires compensation.
+                // If the visitation above changed a value, construct a list to hold the visited values.
+                if (visitedLeftValue != leftValue && visitedLeftValues is null)
+                {
+                    visitedLeftValues = SliceToList(leftRowValue.Values, count, i);
+                }
+                visitedLeftValues?.Add(visitedLeftValue);
+
+                if (visitedRightValue != rightValue && visitedRightValues is null)
+                {
+                    visitedRightValues = SliceToList(rightRowValue.Values, count, i);
+                }
+                visitedRightValues?.Add(visitedRightValue);
+
+                continue;
+            }
+
+            // If we're here, the value pair requires null semantics compensation. We build a binary expression around the pair and visit
+            // that (that adds the compensation). We then chain all such expressions together (AND for equality, OR for inequality).
+            var valueBinaryExpression = Visit(
+                _sqlExpressionFactory.MakeBinary(operatorType, visitedLeftValue, visitedRightValue, null)!, allowOptimizedExpansion, out _);
+
+            if (expandedExpression is null)
+            {
+                // visitedLeft/RightValues hold the pairs that remain in the row value; keep any already-visited values collected so far.
+                visitedLeftValues ??= SliceToList(leftRowValue.Values, count, i);
+                visitedRightValues ??= SliceToList(rightRowValue.Values, count, i);
+                expandedExpression = valueBinaryExpression;
+            }
+            else
+            {
+                expandedExpression = Combine(expandedExpression, valueBinaryExpression);
+            }
+        }
+
+        nullable = false;
+
+        if (expandedExpression is null)
+        {
+            // No pairs required compensation, so they all stay in the row value.
+            // Either return the original binary expression (if no value visitation changed anything), or construct a new one over the
+            // visited values.
+            return visitedLeftValues is null && visitedRightValues is null
+                ? sqlBinaryExpression
+                : _sqlExpressionFactory.MakeBinary(
+                    operatorType,
+                    visitedLeftValues is null
+                        ? leftRowValue
+                        : new PostgresRowValueExpression(visitedLeftValues, leftRowValue.Type, leftRowValue.TypeMapping),
+                    visitedRightValues is null
+                        ? rightRowValue
+                        : new PostgresRowValueExpression(visitedRightValues, rightRowValue.Type, rightRowValue.TypeMapping),
+                    null)!;
+        }
+
+        Check.DebugAssert(visitedLeftValues is not null, "visitedLeftValues is not null");
+        Check.DebugAssert(visitedRightValues is not null, "visitedRightValues is not null");
+
+        // Some pairs required compensation. Combine the pairs which didn't (in visitedLeft/RightValues) with expandedExpression
+        // (which contains the logic for those that did).
+        return visitedLeftValues.Count switch
+        {
+            0 => expandedExpression,
+            1 => Combine(
+                _sqlExpressionFactory.MakeBinary(operatorType, visitedLeftValues[0], visitedRightValues[0], null)!,
+                expandedExpression),
+            // Technically the CLR type and type mappings are incorrect, as we're truncating the row values; that shouldn't matter.
+            _ => Combine(
+                _sqlExpressionFactory.MakeBinary(
+                    operatorType,
+                    new PostgresRowValueExpression(visitedLeftValues, leftRowValue.Type, leftRowValue.TypeMapping),
+                    new PostgresRowValueExpression(visitedRightValues, rightRowValue.Type, rightRowValue.TypeMapping),
+                    null)!,
+                expandedExpression)
+        };
+
+        // Rows are equal when ALL pairs are equal (AND), and unequal when ANY pair is unequal (OR).
+        SqlExpression Combine(SqlExpression left, SqlExpression right)
+            => operatorType == ExpressionType.Equal
+                ? _sqlExpressionFactory.AndAlso(left, right)
+                : _sqlExpressionFactory.OrElse(left, right);
+
+        // Copies the first `count` elements of `source` into a new list created with the given initial capacity.
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        static List<SqlExpression> SliceToList(IReadOnlyList<SqlExpression> source, int capacity, int count)
+        {
+            var list = new List<SqlExpression>(capacity);
+            for (var i = 0; i < count; i++)
+            {
+                list.Add(source[i]);
+            }
+
+            return list;
+        }
+    }
+
     /// <inheritdoc />
     protected override SqlExpression VisitCustomSqlExpression(
         SqlExpression sqlExpression,
@@ -35,7 +183,7 @@ PostgresAllExpression postgresAllExpression
             PostgresArrayIndexExpression arrayIndexExpression
                 => VisitArrayIndex(arrayIndexExpression, allowOptimizedExpansion, out nullable),
             PostgresBinaryExpression binaryExpression
-                => VisitBinary(binaryExpression, allowOptimizedExpansion, out nullable),
+                => VisitPostgresBinary(binaryExpression, allowOptimizedExpansion, out nullable),
             PostgresILikeExpression ilikeExpression
                 => 
VisitILike(ilikeExpression, allowOptimizedExpansion, out nullable),
             PostgresJsonTraversalExpression postgresJsonTraversalExpression
@@ -179,7 +327,7 @@ protected virtual SqlExpression VisitArrayIndex(
     ///     any release. You should only use it directly in your code with extreme caution and knowing that
     ///     doing so can result in application failures when updating to a new Entity Framework Core release.
     /// </summary>
-    protected virtual SqlExpression VisitBinary(PostgresBinaryExpression binaryExpression, bool allowOptimizedExpansion, out bool nullable)
+    protected virtual SqlExpression VisitPostgresBinary(PostgresBinaryExpression binaryExpression, bool allowOptimizedExpansion, out bool nullable)
     {
         Check.NotNull(binaryExpression, nameof(binaryExpression));
 
diff --git a/test/EFCore.PG.FunctionalTests/Query/NullSemanticsQueryNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/Query/NullSemanticsQueryNpgsqlTest.cs
index 8681aebf47..1620e4eef7 100644
--- a/test/EFCore.PG.FunctionalTests/Query/NullSemanticsQueryNpgsqlTest.cs
+++ b/test/EFCore.PG.FunctionalTests/Query/NullSemanticsQueryNpgsqlTest.cs
@@ -6,9 +6,143 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query;
 // ReSharper disable once UnusedMember.Global
 public class NullSemanticsQueryNpgsqlTest : NullSemanticsQueryTestBase<NullSemanticsQueryNpgsqlFixture>
 {
-    public NullSemanticsQueryNpgsqlTest(NullSemanticsQueryNpgsqlFixture fixture)
+    public NullSemanticsQueryNpgsqlTest(NullSemanticsQueryNpgsqlFixture fixture, ITestOutputHelper testOutputHelper)
         : base(fixture)
-        => Fixture.TestSqlLoggerFactory.Clear();
+    {
+        Fixture.TestSqlLoggerFactory.Clear();
+        Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
+    }
+
+    // Pairs where at most one side is nullable are expected to stay inside a single row value comparison (no expansion).
+    [ConditionalTheory, MemberData(nameof(IsAsyncData))]
+    public virtual async Task Compare_row_values_equal_without_expansion(bool async)
+    {
+        await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
+            .Where(e => ValueTuple.Create(e.IntA, e.StringA).Equals(ValueTuple.Create(e.IntB, e.StringB)))
+            .Select(e => e.Id));
+
+        await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
+            .Where(e => ValueTuple.Create(e.IntA, e.StringA).Equals(ValueTuple.Create(e.IntB, e.NullableStringB)))
+            .Select(e => e.Id));
+
+        await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
+            .Where(e => ValueTuple.Create(e.IntA, e.NullableStringA).Equals(ValueTuple.Create(e.IntB, e.StringB)))
+            .Select(e => e.Id));
+
+        AssertSql(
+"""
+SELECT e."Id"
+FROM "Entities1" AS e
+WHERE (e."IntA", e."StringA") = (e."IntB", e."StringB")
+""",
+            // Only the right-hand string is nullable
+"""
+SELECT e."Id"
+FROM "Entities1" AS e
+WHERE (e."IntA", e."StringA") = (e."IntB", e."NullableStringB")
+""",
+            // Only the left-hand string is nullable
+"""
+SELECT e."Id"
+FROM "Entities1" AS e
+WHERE (e."IntA", e."NullableStringA") = (e."IntB", e."StringB")
+""");
+    }
+
+    // A pair that is nullable on both sides is expected to be pulled out of the row value and compared with null compensation.
+    [ConditionalTheory, MemberData(nameof(IsAsyncData))]
+    public virtual async Task Compare_row_values_equal_with_expansion(bool async)
+    {
+        await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
+            .Where(e => ValueTuple.Create(e.NullableStringA, e.IntA, e.BoolA).Equals(ValueTuple.Create(e.NullableStringB, e.IntB, e.BoolB)))
+            .Select(e => e.Id));
+
+        await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
+            .Where(e => ValueTuple.Create(e.IntA, e.NullableStringA, e.BoolA).Equals(ValueTuple.Create(e.IntB, e.NullableStringB, e.BoolB)))
+            .Select(e => e.Id));
+
+        await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
+            .Where(e => ValueTuple.Create(e.IntA, e.StringA, e.NullableBoolA).Equals(ValueTuple.Create(e.IntB, e.StringB, e.NullableBoolB)))
+            .Select(e => e.Id));
+
+        AssertSql(
+"""
+SELECT e."Id"
+FROM "Entities1" AS e
+WHERE (e."IntA", e."BoolA") = (e."IntB", e."BoolB") AND (e."NullableStringA" = e."NullableStringB" OR ((e."NullableStringA" IS NULL) AND (e."NullableStringB" IS NULL)))
+""",
+            // Nullable pair in the middle position
+"""
+SELECT e."Id"
+FROM "Entities1" AS e
+WHERE (e."IntA", e."BoolA") = (e."IntB", e."BoolB") AND (e."NullableStringA" = e."NullableStringB" OR ((e."NullableStringA" IS NULL) AND (e."NullableStringB" IS NULL)))
+""",
+            // Nullable pair in the last position
+"""
+SELECT e."Id"
+FROM "Entities1" AS e
+WHERE (e."IntA", e."StringA") = (e."IntB", e."StringB") AND (e."NullableBoolA" = e."NullableBoolB" OR ((e."NullableBoolA" IS NULL) AND (e."NullableBoolB" IS NULL)))
+""");
+    }
+
+    // Negated row value equality: a plain <> for non-nullable values, otherwise an OR-chain with null compensation.
+    [ConditionalTheory, MemberData(nameof(IsAsyncData))]
+    public virtual async Task Compare_row_values_not_equal(bool async)
+    {
+        await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
+            .Where(e => !ValueTuple.Create(e.IntA, e.StringA).Equals(ValueTuple.Create(e.IntB, e.StringB)))
+            .Select(e => e.Id));
+
+        await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
+            .Where(e => !ValueTuple.Create(e.IntA, e.StringA).Equals(ValueTuple.Create(e.IntB, e.NullableStringB)))
+            .Select(e => e.Id));
+
+        await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
+            .Where(e => !ValueTuple.Create(e.IntA, e.NullableStringA).Equals(ValueTuple.Create(e.IntB, e.StringB)))
+            .Select(e => e.Id));
+
+        await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
+            .Where(e => !ValueTuple.Create(e.IntA, e.NullableStringA).Equals(ValueTuple.Create(e.IntB, e.NullableStringB)))
+            .Select(e => e.Id));
+
+        await AssertQueryScalar(async, ss => ss.Set<NullSemanticsEntity1>()
+            .Where(e => !ValueTuple.Create(e.IntA, e.StringA, e.NullableBoolA).Equals(ValueTuple.Create(e.IntB, e.StringB, e.NullableBoolB)))
+            .Select(e => e.Id));
+
+        AssertSql(
+"""
+SELECT e."Id"
+FROM "Entities1" AS e
+WHERE (e."IntA", e."StringA") <> (e."IntB", e."StringB")
+""",
+            // Only the right-hand string is nullable
+"""
+SELECT e."Id"
+FROM "Entities1" AS e
+WHERE e."IntA" <> e."IntB" OR e."StringA" <> e."NullableStringB" OR (e."NullableStringB" IS NULL)
+""",
+            // Only the left-hand string is nullable
+"""
+SELECT e."Id"
+FROM "Entities1" AS e
+WHERE e."IntA" <> e."IntB" OR e."NullableStringA" <> e."StringB" OR (e."NullableStringA" IS NULL)
+""",
+            // Both strings are nullable
+"""
+SELECT e."Id"
+FROM "Entities1" AS e
+WHERE e."IntA" <> e."IntB" OR ((e."NullableStringA" <> e."NullableStringB" OR (e."NullableStringA" IS NULL) OR (e."NullableStringB" IS NULL)) AND ((e."NullableStringA" IS NOT NULL) OR (e."NullableStringB" IS NOT NULL)))
+""",
+            // Three values, the last pair nullable on both sides
+"""
+SELECT e."Id"
+FROM "Entities1" AS e
+WHERE (e."IntA", e."StringA") <> (e."IntB", e."StringB") OR ((e."NullableBoolA" <> e."NullableBoolB" OR (e."NullableBoolA" IS NULL) OR (e."NullableBoolB" IS NULL)) AND ((e."NullableBoolA" IS NOT NULL) OR (e."NullableBoolB" IS NOT NULL)))
+""");
+    }
+
+    private void AssertSql(params string[] expected)
+        => Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
 
     protected override void ClearLog()
         => Fixture.TestSqlLoggerFactory.Clear();
@@ -27,4 +161,4 @@ protected override NullSemanticsContext CreateContext(bool useRelationalNulls =
 
         return context;
     }
-}
\ No newline at end of file
+}