Skip to content

Commit

Permalink
Fix #1909 - Do not append Select when result operator performs projec…
Browse files Browse the repository at this point in the history
…tion.
  • Loading branch information
mikary committed Jan 7, 2016
1 parent a54fdf0 commit 8b0bdca
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 15 deletions.
22 changes: 16 additions & 6 deletions src/EntityFramework.Core/Query/EntityQueryModelVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
using Remotion.Linq.Clauses;
using Remotion.Linq.Clauses.Expressions;
using Remotion.Linq.Clauses.ExpressionVisitors;
using Remotion.Linq.Clauses.ResultOperators;
using Remotion.Linq.Clauses.StreamedData;

namespace Microsoft.Data.Entity.Query
Expand All @@ -33,6 +34,12 @@ public static readonly ParameterExpression QueryContextParameter
public static readonly MethodInfo PropertyMethodInfo
= typeof(EF).GetTypeInfo().GetDeclaredMethod(nameof(EF.Property));

private static readonly HashSet<Type> _projectingResultOperators = new HashSet<Type>
{
typeof(GroupResultOperator),
typeof(AllResultOperator)
};

private readonly IQueryOptimizer _queryOptimizer;
private readonly INavigationRewritingExpressionVisitorFactory _navigationRewritingExpressionVisitorFactory;
private readonly ISubQueryMemberPushDownExpressionVisitor _subQueryMemberPushDownExpressionVisitor;
Expand Down Expand Up @@ -795,12 +802,15 @@ var selector

if (selector.Type != sequenceType)
{
_expression
= Expression.Call(
LinqOperatorProvider.Select
.MakeGenericMethod(CurrentParameter.Type, selector.Type),
_expression,
Expression.Lambda(selector, CurrentParameter));
if (!queryModel.ResultOperators.Any(ro => _projectingResultOperators.Contains(ro.GetType())))
{
_expression
= Expression.Call(
LinqOperatorProvider.Select
.MakeGenericMethod(CurrentParameter.Type, selector.Type),
_expression,
Expression.Lambda(selector, CurrentParameter));
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/EntityFramework.Core/Query/ResultOperatorHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ var expression
Expression.Lambda(elementSelector, entityQueryModelVisitor.CurrentParameter));

entityQueryModelVisitor.CurrentParameter
= Expression.Parameter(sequenceType, groupResultOperator.ItemName);
= Expression.Parameter(expression.Type.GetSequenceType(), groupResultOperator.ItemName);

entityQueryModelVisitor
.AddOrUpdateMapping(groupResultOperator, entityQueryModelVisitor.CurrentParameter);
Expand Down
115 changes: 107 additions & 8 deletions test/EntityFramework.Core.FunctionalTests/QueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public virtual void Take_simple_projection()
cs => cs.OrderBy(c => c.CustomerID).Select(c => c.City).Take(10),
assertOrder: true);
}

[ConditionalFact]
public virtual void Take_subquery_projection()
{
Expand Down Expand Up @@ -283,7 +283,7 @@ public virtual void Any_nested_negated()

[ConditionalFact]
public virtual void Any_nested_negated2()
{
{
AssertQuery<Customer, Order>(
(cs, os) => cs.Where(c => c.City != "London"
&& !os.Any(o => o.CustomerID.StartsWith("A"))));
Expand All @@ -295,7 +295,7 @@ public virtual void Any_nested_negated3()
AssertQuery<Customer, Order>(
(cs, os) => cs.Where(c => !os.Any(o => o.CustomerID.StartsWith("A"))
&& c.City != "London"));
}
}

[ConditionalFact]
public virtual void Any_nested()
Expand Down Expand Up @@ -1027,7 +1027,7 @@ public virtual void Where_datetime_now()
cs => cs.Where(c => DateTime.Now != myDatetime),
entryCount: 91);
}

[ConditionalFact]
public virtual void Where_datetime_utcnow()
{
Expand Down Expand Up @@ -2301,7 +2301,7 @@ from c in cs
join o in os on new Foo { Bar = c.CustomerID } equals new Foo { Bar = o.CustomerID }
select new { c, o });
}

[ConditionalFact]
public virtual void Join_local_collection_int_closure_is_cached_correctly()
{
Expand Down Expand Up @@ -2944,8 +2944,6 @@ public virtual void GroupBy_with_element_selector_sum()
os.GroupBy(o => o.CustomerID, o => o.OrderID).Select(g => g.Sum()));
}



[ConditionalFact]
public virtual void GroupBy_with_element_selector()
{
Expand Down Expand Up @@ -3053,6 +3051,107 @@ public virtual void OrderBy_GroupBy_SelectMany_shadow()
.Select(g => EF.Property<string>(g, "Title")));
}

[ConditionalFact]
public virtual void Select_GroupBy()
{
AssertQuery<Order>(
os => os.Select(o => new ProjectedType
{
Order = o.OrderID,
Customer = o.CustomerID
})
.GroupBy(p => p.Customer),
asserter:
(l2oResults, efResults) =>
{
var efGroupings = efResults.Cast<IGrouping<string, ProjectedType>>().ToList();

foreach (IGrouping<string, ProjectedType> l2oGrouping in l2oResults)
{
var efGrouping = efGroupings.Single(efg => efg.Key == l2oGrouping.Key);

Assert.Equal(l2oGrouping.OrderBy(p => p.Order), efGrouping.OrderBy(p => p.Order));
}
});
}

[ConditionalFact]
public virtual void Select_GroupBy_SelectMany()
{
AssertQuery<Order>(
os => os.Select(o => new ProjectedType
{
Order = o.OrderID,
Customer = o.CustomerID
})
.GroupBy(o => o.Order)
.SelectMany(g => g));
}

[ConditionalFact]
public virtual void Select_All()
{
using (var context = CreateContext())
{
Assert.Equal(
false,
context
.Set<Order>()
.Select(o => new ProjectedType
{
Order = o.OrderID,
Customer = o.CustomerID
})
.All(p => p.Customer == "ALFKI")
);
}
}

[ConditionalFact]
public virtual void Select_GroupBy_All()
{
using (var context = CreateContext())
{
Assert.Equal(
false,
context
.Set<Order>()
.Select(o => new ProjectedType
{
Order = o.OrderID,
Customer = o.CustomerID
})
.GroupBy(a => a.Customer)
.All(a => a.Key == "ALFKI")
);
}
}

private class ProjectedType
{
public int Order { get; set; }
public string Customer { get; set; }

protected bool Equals(ProjectedType other) => string.Equals(Order, other.Order);

public override bool Equals(object obj)
{
if (ReferenceEquals(null, obj))
{
return false;
}
if (ReferenceEquals(this, obj))
{
return true;
}

return obj.GetType() == GetType()
&& Equals((ProjectedType)obj);
}

public override int GetHashCode() => Order.GetHashCode();
}

[ConditionalFact]
public virtual void Sum_with_no_arg()
{
Expand Down Expand Up @@ -4096,7 +4195,7 @@ public virtual void Contains_with_local_array_closure()
AssertQuery<Customer>(cs =>
cs.Where(c => ids.Contains(c.CustomerID)), entryCount: 1);

ids = new []{ "ABCDE" };
ids = new[] { "ABCDE" };

AssertQuery<Customer>(cs =>
cs.Where(c => ids.Contains(c.CustomerID)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2486,6 +2486,28 @@ FROM [Orders] AS [o]
Sql);
}

public override void Select_GroupBy()
{
base.Select_GroupBy();

Assert.Equal(
@"SELECT [o].[OrderID], [o].[CustomerID]
FROM [Orders] AS [o]
ORDER BY [o].[CustomerID]",
Sql);
}

public override void Select_GroupBy_SelectMany()
{
base.Select_GroupBy_SelectMany();

Assert.Equal(
@"SELECT [o].[OrderID], [o].[CustomerID]
FROM [Orders] AS [o]
ORDER BY [o].[OrderID]",
Sql);
}

public override void SelectMany_cartesian_product_with_ordering()
{
base.SelectMany_cartesian_product_with_ordering();
Expand Down

0 comments on commit 8b0bdca

Please sign in to comment.