Skip to content

Commit

Permalink
Query: Translate Contains only when it is a server-side list
Browse files Browse the repository at this point in the history
Based on how values are expanded into InExpression values cannot be anything other than SqlConstant/SqlParameter

Resolves #18970

Apply fix for #17342 for cosmos
Also enable Contains tests in Cosmos which were already working
  • Loading branch information
smitpatel committed Dec 13, 2019
1 parent bdff0a1 commit 86fc11f
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 46 deletions.
17 changes: 12 additions & 5 deletions src/EFCore.Cosmos/Query/Internal/ContainsTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,27 @@ public ContainsTranslator([NotNull] ISqlExpressionFactory sqlExpressionFactory)
public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method, IReadOnlyList<SqlExpression> arguments)
{
if (method.IsGenericMethod
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains))
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)
&& ValidateValues(arguments[0]))
{
return _sqlExpressionFactory.In(arguments[1], arguments[0], false);
}

if ((method.DeclaringType.GetInterfaces().Contains(typeof(IList))
|| method.DeclaringType.IsGenericType
&& method.DeclaringType.GetGenericTypeDefinition() == typeof(ICollection<>))
&& string.Equals(method.Name, nameof(IList.Contains)))
if (method.Name == nameof(IList.Contains)
&& arguments.Count == 1
&& method.DeclaringType.GetInterfaces().Append(method.DeclaringType).Any(
t => t == typeof(IList)
|| (t.IsGenericType
&& t.GetGenericTypeDefinition() == typeof(ICollection<>)))
&& ValidateValues(instance))
{
return _sqlExpressionFactory.In(arguments[0], instance, false);
}

return null;
}

private bool ValidateValues(SqlExpression values)
=> values is SqlConstantExpression || values is SqlParameterExpression;
}
}
9 changes: 7 additions & 2 deletions src/EFCore.Relational/Query/Internal/ContainsTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
Check.NotNull(arguments, nameof(arguments));

if (method.IsGenericMethod
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains))
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)
&& ValidateValues(arguments[0]))
{
return _sqlExpressionFactory.In(arguments[1], arguments[0], negated: false);
}
Expand All @@ -36,12 +37,16 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
&& method.DeclaringType.GetInterfaces().Append(method.DeclaringType).Any(
t => t == typeof(IList)
|| (t.IsGenericType
&& t.GetGenericTypeDefinition() == typeof(ICollection<>))))
&& t.GetGenericTypeDefinition() == typeof(ICollection<>)))
&& ValidateValues(instance))
{
return _sqlExpressionFactory.In(arguments[0], instance, negated: false);
}

return null;
}

private bool ValidateValues(SqlExpression values)
=> values is SqlConstantExpression || values is SqlParameterExpression;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1273,7 +1273,7 @@ FROM root c
WHERE (c[""Discriminator""] = ""Customer"")");
}

[ConditionalFact(Skip = "Issue#17246 (Contains not implemented)")]
[ConditionalFact(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override void Contains_over_entityType_should_rewrite_to_identity_equality()
{
base.Contains_over_entityType_should_rewrite_to_identity_equality();
Expand All @@ -1284,7 +1284,7 @@ FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}

[ConditionalTheory(Skip = "Issue#17246 (Contains not implemented)")]
[ConditionalFact(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override async Task List_Contains_over_entityType_should_rewrite_to_identity_equality(bool async)
{
await base.List_Contains_over_entityType_should_rewrite_to_identity_equality(async);
Expand All @@ -1295,73 +1295,67 @@ FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}

[ConditionalTheory(Skip = "Issue#17246 (Contains not implemented)")]
public override async Task List_Contains_with_constant_list(bool async)
{
await base.List_Contains_with_constant_list(async);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
WHERE ((c[""Discriminator""] = ""Customer"") AND c[""CustomerID""] IN (""ALFKI"", ""ANATR""))");
}

[ConditionalTheory(Skip = "Issue#17246 (Contains not implemented)")]
public override async Task List_Contains_with_parameter_list(bool async)
{
await base.List_Contains_with_parameter_list(async);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
WHERE ((c[""Discriminator""] = ""Customer"") AND c[""CustomerID""] IN (""ALFKI"", ""ANATR""))");
}

[ConditionalTheory(Skip = "Issue#17246 (Contains not implemented)")]
public override async Task Contains_with_parameter_list_value_type_id(bool async)
{
await base.Contains_with_parameter_list_value_type_id(async);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
WHERE ((c[""Discriminator""] = ""Order"") AND c[""OrderID""] IN (10248, 10249))");
}

[ConditionalTheory(Skip = "Issue#17246 (Contains not implemented)")]
public override async Task Contains_with_constant_list_value_type_id(bool async)
{
await base.Contains_with_constant_list_value_type_id(async);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
WHERE ((c[""Discriminator""] = ""Order"") AND c[""OrderID""] IN (10248, 10249))");
}

[ConditionalTheory(Skip = "Issue#17246 (Contains not implemented)")]
public override async Task HashSet_Contains_with_parameter(bool async)
{
await base.HashSet_Contains_with_parameter(async);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
WHERE ((c[""Discriminator""] = ""Customer"") AND c[""CustomerID""] IN (""ALFKI""))");
}

[ConditionalTheory(Skip = "Issue#17246 (Contains not implemented)")]
public override async Task ImmutableHashSet_Contains_with_parameter(bool async)
{
await base.ImmutableHashSet_Contains_with_parameter(async);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
WHERE ((c[""Discriminator""] = ""Customer"") AND c[""CustomerID""] IN (""ALFKI""))");
}

[ConditionalFact(Skip = "Issue#17246 (Contains not implemented)")]
[ConditionalFact(Skip = "Issue#17246 (Contains over subquery is not supported)")]
public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality();
Expand All @@ -1372,17 +1366,6 @@ FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}

[ConditionalFact(Skip = "Issue #17246")]
public override void Contains_over_entityType_should_materialize_when_composite()
{
base.Contains_over_entityType_should_materialize_when_composite();

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""OrderDetail"") AND ((c[""OrderID""] = 10248) AND (c[""ProductID""] = 42)))");
}

public override async Task String_FirstOrDefault_in_projection_does_client_eval(bool async)
{
await base.String_FirstOrDefault_in_projection_does_client_eval(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,15 @@ public override void Value_conversion_with_property_named_value()
}

[ConditionalFact(Skip = "Issue#17050")]
public override void Collection_property_as_scalar()
public override void Collection_property_as_scalar_Any()
{
base.Collection_property_as_scalar();
base.Collection_property_as_scalar_Any();
}

[ConditionalFact(Skip = "Issue#17050")]
public override void Collection_enum_as_string_Contains()
{
base.Collection_enum_as_string_Contains();
}

public class CustomConvertersInMemoryFixture : CustomConvertersFixtureBase
Expand Down
63 changes: 52 additions & 11 deletions test/EFCore.Specification.Tests/CustomConvertersTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ public virtual void Can_query_and_update_with_nullable_converter_on_primary_key(
var principal = context.Add(
new NullablePrincipal
{
Id = 1, Dependents = new List<NonNullableDependent> { new NonNullableDependent { Id = 1 } }
Id = 1,
Dependents = new List<NonNullableDependent> { new NonNullableDependent { Id = 1 } }
})
.Entity;

Expand Down Expand Up @@ -520,14 +521,14 @@ protected class ValueWrapper
}

[ConditionalFact]
public virtual void Collection_property_as_scalar()
public virtual void Collection_property_as_scalar_Any()
{
using var context = CreateContext();
Assert.Equal(
@"The LINQ expression 'DbSet<CollectionScalar> .Where(c => c.Tags .Any())' could not be translated. Either rewrite the query in a form that can be translated, or switch to client evaluation explicitly by inserting a call to either AsEnumerable(), AsAsyncEnumerable(), ToList(), or ToListAsync(). See https://go.microsoft.com/fwlink/?linkid=2101038 for more information.",
Assert.Throws<InvalidOperationException>(
() => context.Set<CollectionScalar>().Where(e => e.Tags.Any()).ToList())
.Message.Replace("\r","").Replace("\n",""));
.Message.Replace("\r", "").Replace("\n", ""));
}

protected class CollectionScalar
Expand All @@ -536,6 +537,31 @@ protected class CollectionScalar
public List<string> Tags { get; set; }
}

[ConditionalFact]
public virtual void Collection_enum_as_string_Contains()
{
using var context = CreateContext();
var sameRole = Roles.Seller;
Assert.Equal(
@"The LINQ expression 'DbSet<CollectionEnum> .Where(c => c.Roles.Contains(__sameRole_0))' could not be translated. Either rewrite the query in a form that can be translated, or switch to client evaluation explicitly by inserting a call to either AsEnumerable(), AsAsyncEnumerable(), ToList(), or ToListAsync(). See https://go.microsoft.com/fwlink/?linkid=2101038 for more information.",
Assert.Throws<InvalidOperationException>(
() => context.Set<CollectionEnum>().Where(e => e.Roles.Contains(sameRole)).ToList())
.Message.Replace("\r", "").Replace("\n", ""));

}

protected class CollectionEnum
{
public int Id { get; set; }
public ICollection<Roles> Roles { get; set; }
}

protected enum Roles
{
Customer,
Seller
}

public abstract class CustomConvertersFixtureBase : BuiltInDataTypesFixtureBase
{
protected override string StoreName { get; } = "CustomConverters";
Expand Down Expand Up @@ -970,22 +996,27 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
b.Property(e => e.Tags).HasConversion(
c => string.Join(",", c),
s => s.Split(',', StringSplitOptions.None).ToList()).Metadata
.SetValueComparer(new ListOfStringComparer());
.SetValueComparer(new ValueComparer<List<string>>(favorStructuralComparisons: true));

b.HasData(new CollectionScalar
{
Id = 1,
Tags = new List<string> { "A", "B", "C" }
});
});
}

private class ListOfStringComparer : ValueComparer<List<string>>
{
public ListOfStringComparer()
: base(favorStructuralComparisons: true)
{
}
modelBuilder.Entity<CollectionEnum>(
b =>
{
b.Property(e => e.Roles).HasConversion(new RolesToStringConveter()).Metadata
.SetValueComparer(new ValueComparer<ICollection<Roles>>(favorStructuralComparisons: true));

b.HasData(new CollectionEnum
{
Id = 1,
Roles = new List<Roles> { Roles.Seller }
});
});
}

private static class StringToDictionarySerializer
Expand Down Expand Up @@ -1033,6 +1064,16 @@ public UrlSchemeRemover()
{
}
}

private class RolesToStringConveter : ValueConverter<ICollection<Roles>, string>
{
public RolesToStringConveter()
: base(v => string.Join(";", v.Select(f => f.ToString())),
v => v.Length > 0
? v.Split(new[] { ';' }).Select(f => (Roles)Enum.Parse(typeof(Roles), f)).ToList()
: new List<Roles>())
{ }
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ public virtual void Columns_have_expected_data_types()
BuiltInNullableDataTypesShadow.TestNullableUnsignedInt32 ---> [nullable bigint] [Precision = 19 Scale = 0]
BuiltInNullableDataTypesShadow.TestNullableUnsignedInt64 ---> [nullable decimal] [Precision = 20 Scale = 0]
BuiltInNullableDataTypesShadow.TestString ---> [nullable nvarchar] [MaxLength = -1]
CollectionEnum.Id ---> [int] [Precision = 10 Scale = 0]
CollectionEnum.Roles ---> [nullable nvarchar] [MaxLength = -1]
CollectionScalar.Id ---> [int] [Precision = 10 Scale = 0]
CollectionScalar.Tags ---> [nullable nvarchar] [MaxLength = -1]
EmailTemplate.Id ---> [uniqueidentifier]
Expand Down

0 comments on commit 86fc11f

Please sign in to comment.