Skip to content

Commit

Permalink
Fix to #5522 - Model with nullable FK fails parent child query with "…
Browse files Browse the repository at this point in the history
…Argument Types do not match"

Problem was that during navigation rewrite we sometimes change types of expressions, e.g.

o.Customer.Id (originally int)

would get converted to:

(o != null) ? (int?)o.CustomerId : null (change type to int?)

We try to compensate for this later, by casting back to the original type, but we missed some cases: MemberAssignment, ElementInit, NewArray.

Fix is to add the compensation for those nodes.
  • Loading branch information
maumar committed Jul 8, 2016
1 parent bd1fa17 commit 8e7c9a9
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1533,6 +1533,42 @@ public virtual void Optional_navigation_type_compensation_works_with_projection_
}
}

[ConditionalFact]
public virtual void Optional_navigation_type_compensation_works_with_DTOs()
{
using (var context = CreateContext())
{
var query = context.Tags.Where(t => t.Note != "K.I.A.").Select(t => new Squad { Id = t.Gear.SquadId });
var result = query.ToList();

Assert.Equal(5, result.Count);
}
}

[ConditionalFact]
public virtual void Optional_navigation_type_compensation_works_with_list_initializers()
{
using (var context = CreateContext())
{
var query = context.Tags.Where(t => t.Note != "K.I.A.").Select(t => new List<int> { t.Gear.SquadId, t.Gear.SquadId + 1, 42 });
var result = query.ToList();

Assert.Equal(5, result.Count);
}
}

[ConditionalFact]
public virtual void Optional_navigation_type_compensation_works_with_array_initializers()
{
using (var context = CreateContext())
{
var query = context.Tags.Where(t => t.Note != "K.I.A.").Select(t => new int[] { t.Gear.SquadId });
var result = query.ToList();

Assert.Equal(5, result.Count);
}
}

[ConditionalFact]
public virtual void Optional_navigation_type_compensation_works_with_orderby()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,53 @@ protected override Expression VisitMember(MemberExpression node)
?? base.VisitMember(node);
}

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
protected override MemberAssignment VisitMemberAssignment(MemberAssignment node)
{
var newExpression = CompensateForNullabilityDifference(
Visit(node.Expression),
node.Expression.Type);

return node.Update(newExpression);
}

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
protected override ElementInit VisitElementInit(ElementInit node)
{
var originalArgumentTypes = node.Arguments.Select(a => a.Type).ToList();
var newArguments = node.Arguments.Select(Visit).ToList();

for (var i = 0; i < newArguments.Count; i++)
{
newArguments[i] = CompensateForNullabilityDifference(newArguments[i], originalArgumentTypes[i]);
}

return node.Update(newArguments);
}

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
protected override Expression VisitNewArray(NewArrayExpression node)
{
var originalExpressionTypes = node.Expressions.Select(e => e.Type).ToList();
var newExpressions = node.Expressions.Select(Visit).ToList();

for (var i = 0; i < newExpressions.Count; i++)
{
newExpressions[i] = CompensateForNullabilityDifference(newExpressions[i], originalExpressionTypes[i]);
}

return node.Update(newExpressions);
}

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
Expand Down Expand Up @@ -478,7 +525,6 @@ private Expression RewriteNavigationProperties(
return default(Expression);
}


private static Expression CreateForeignKeyMemberAccess(string propertyName, Expression declaringExpression, INavigation navigation)
{
var principalKey = navigation.ForeignKey.PrincipalKey;
Expand Down Expand Up @@ -927,6 +973,20 @@ private static readonly MethodInfo _createEntityQueryableMethod
private static EntityQueryable<TResult> _CreateEntityQueryable<TResult>(IAsyncQueryProvider entityQueryProvider)
=> new EntityQueryable<TResult>(entityQueryProvider);

private static Expression CompensateForNullabilityDifference(Expression expression, Type originalType)
{
var newType = expression.Type;

var needsTypeCompensation = (originalType != expression.Type)
&& !originalType.IsNullableType()
&& newType.IsNullableType()
&& (originalType.UnwrapNullableType() == newType.UnwrapNullableType());

return needsTypeCompensation
? Expression.Convert(expression, originalType)
: expression;
}

private class NavigationRewritingQueryModelVisitor : ExpressionTransformingQueryModelVisitor
{
private readonly SubqueryInjector _subqueryInjector;
Expand Down Expand Up @@ -970,19 +1030,11 @@ public override void VisitOrderByClause(OrderByClause orderByClause, QueryModel

base.VisitOrderByClause(orderByClause, queryModel, index);

var newTypes = orderByClause.Orderings.Select(o => o.Expression.Type).ToList();

Debug.Assert(originalTypes.Count == newTypes.Count);

for (var i = 0; i < newTypes.Count; i++)
for (var i = 0; i < orderByClause.Orderings.Count; i++)
{
if ((originalTypes[i] != newTypes[i])
&& !originalTypes[i].IsNullableType()
&& newTypes[i].IsNullableType()
&& (originalTypes[i].UnwrapNullableType() == newTypes[i].UnwrapNullableType()))
{
orderByClause.Orderings[i].Expression = Expression.Convert(orderByClause.Orderings[i].Expression, originalTypes[i]);
}
orderByClause.Orderings[i].Expression = CompensateForNullabilityDifference(
orderByClause.Orderings[i].Expression,
originalTypes[i]);
}
}

Expand Down Expand Up @@ -1030,14 +1082,7 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel que

base.VisitSelectClause(selectClause, queryModel);

var newType = selectClause.Selector.Type;
if ((originalType != newType)
&& !originalType.IsNullableType()
&& newType.IsNullableType()
&& (originalType.UnwrapNullableType() == newType.UnwrapNullableType()))
{
selectClause.Selector = Expression.Convert(selectClause.Selector, originalType);
}
selectClause.Selector = CompensateForNullabilityDifference(selectClause.Selector, originalType);
}

public override void VisitResultOperator(ResultOperatorBase resultOperator, QueryModel queryModel, int index)
Expand Down Expand Up @@ -1090,24 +1135,13 @@ public override void VisitResultOperator(ResultOperatorBase resultOperator, Quer

base.VisitResultOperator(resultOperator, queryModel, index);

var newKeySelectorType = groupResultOperator.KeySelector.Type;
var newElementSelectorType = groupResultOperator.ElementSelector.Type;

if (originalKeySelectorType != newKeySelectorType
&& !originalKeySelectorType.IsNullableType()
&& newKeySelectorType.IsNullableType()
&& originalKeySelectorType.UnwrapNullableType() == newKeySelectorType.UnwrapNullableType())
{
groupResultOperator.KeySelector = Expression.Convert(groupResultOperator.KeySelector, originalKeySelectorType);
}
groupResultOperator.KeySelector = CompensateForNullabilityDifference(
groupResultOperator.KeySelector,
originalKeySelectorType);

if (originalElementSelectorType != newElementSelectorType
&& !originalElementSelectorType.IsNullableType()
&& newElementSelectorType.IsNullableType()
&& originalElementSelectorType.UnwrapNullableType() == newElementSelectorType.UnwrapNullableType())
{
groupResultOperator.ElementSelector = Expression.Convert(groupResultOperator.ElementSelector, originalElementSelectorType);
}
groupResultOperator.ElementSelector = CompensateForNullabilityDifference(
groupResultOperator.ElementSelector,
originalElementSelectorType);

return;
}
Expand All @@ -1124,16 +1158,9 @@ private void VisitAndAdjustResultOperatorType<TResultOperator>(
var originalExpression = expressionExtractor(resultOperator);
var originalType = originalExpression.Type;

var translatedExpression = TransformingVisitor.Visit(originalExpression);

var newType = translatedExpression.Type;
if ((originalType != newType)
&& !originalType.IsNullableType()
&& newType.IsNullableType()
&& (originalType.UnwrapNullableType() == newType.UnwrapNullableType()))
{
translatedExpression = Expression.Convert(translatedExpression, originalType);
}
var translatedExpression = CompensateForNullabilityDifference(
TransformingVisitor.Visit(originalExpression),
originalType);

adjuster(resultOperator, translatedExpression);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1570,6 +1570,45 @@ LEFT JOIN [Gear] AS [t.Gear] ON ([t].[GearNickName] = [t.Gear].[Nickname]) AND (
Sql);
}

public override void Optional_navigation_type_compensation_works_with_DTOs()
{
base.Optional_navigation_type_compensation_works_with_DTOs();

Assert.Equal(
@"SELECT [t].[Id], [t].[GearNickName], [t].[GearSquadId], [t].[Note], [t.Gear].[Nickname], [t.Gear].[SquadId], [t.Gear].[AssignedCityName], [t.Gear].[CityOrBirthName], [t.Gear].[Discriminator], [t.Gear].[FullName], [t.Gear].[HasSoulPatch], [t.Gear].[LeaderNickname], [t.Gear].[LeaderSquadId], [t.Gear].[Rank]
FROM [CogTag] AS [t]
LEFT JOIN [Gear] AS [t.Gear] ON ([t].[GearNickName] = [t.Gear].[Nickname]) AND ([t].[GearSquadId] = [t.Gear].[SquadId])
WHERE ([t].[Note] <> N'K.I.A.') OR [t].[Note] IS NULL
ORDER BY [t].[GearNickName], [t].[GearSquadId]",
Sql);
}

public override void Optional_navigation_type_compensation_works_with_list_initializers()
{
base.Optional_navigation_type_compensation_works_with_list_initializers();

Assert.Equal(
@"SELECT [t].[Id], [t].[GearNickName], [t].[GearSquadId], [t].[Note], [t.Gear].[Nickname], [t.Gear].[SquadId], [t.Gear].[AssignedCityName], [t.Gear].[CityOrBirthName], [t.Gear].[Discriminator], [t.Gear].[FullName], [t.Gear].[HasSoulPatch], [t.Gear].[LeaderNickname], [t.Gear].[LeaderSquadId], [t.Gear].[Rank], 1
FROM [CogTag] AS [t]
LEFT JOIN [Gear] AS [t.Gear] ON ([t].[GearNickName] = [t.Gear].[Nickname]) AND ([t].[GearSquadId] = [t.Gear].[SquadId])
WHERE ([t].[Note] <> N'K.I.A.') OR [t].[Note] IS NULL
ORDER BY [t].[GearNickName], [t].[GearSquadId]",
Sql);
}

public override void Optional_navigation_type_compensation_works_with_array_initializers()
{
base.Optional_navigation_type_compensation_works_with_array_initializers();

Assert.Equal(
@"SELECT [t].[Id], [t].[GearNickName], [t].[GearSquadId], [t].[Note], [t.Gear].[Nickname], [t.Gear].[SquadId], [t.Gear].[AssignedCityName], [t.Gear].[CityOrBirthName], [t.Gear].[Discriminator], [t.Gear].[FullName], [t.Gear].[HasSoulPatch], [t.Gear].[LeaderNickname], [t.Gear].[LeaderSquadId], [t.Gear].[Rank]
FROM [CogTag] AS [t]
LEFT JOIN [Gear] AS [t.Gear] ON ([t].[GearNickName] = [t.Gear].[Nickname]) AND ([t].[GearSquadId] = [t.Gear].[SquadId])
WHERE ([t].[Note] <> N'K.I.A.') OR [t].[Note] IS NULL
ORDER BY [t].[GearNickName], [t].[GearSquadId]",
Sql);
}

public override void Optional_navigation_type_compensation_works_with_orderby()
{
base.Optional_navigation_type_compensation_works_with_orderby();
Expand Down

0 comments on commit 8e7c9a9

Please sign in to comment.