Skip to content

Commit

Permalink
Support conditional expression in projection
Browse files Browse the repository at this point in the history
  • Loading branch information
smitpatel committed Jan 8, 2016
1 parent 7f97473 commit 2837445
Show file tree
Hide file tree
Showing 12 changed files with 285 additions and 34 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using JetBrains.Annotations;
using Microsoft.Data.Entity.Query.Expressions;
using Microsoft.Data.Entity.Query.Expressions.Internal;
using Microsoft.Data.Entity.Storage;
using Microsoft.Data.Entity.Utilities;
using Remotion.Linq.Parsing;

namespace Microsoft.Data.Entity.Query.Sql.Internal
{
Expand All @@ -20,10 +22,10 @@ public SqlServerQuerySqlGenerator(
[NotNull] IRelationalTypeMapper relationalTypeMapper,
[NotNull] SelectExpression selectExpression)
: base(
relationalCommandBuilderFactory,
sqlGenerationHelper,
parameterNameGeneratorFactory,
relationalTypeMapper,
relationalCommandBuilderFactory,
sqlGenerationHelper,
parameterNameGeneratorFactory,
relationalTypeMapper,
selectExpression)
{
}
Expand Down Expand Up @@ -69,6 +71,14 @@ protected override void GenerateLimitOffset(SelectExpression selectExpression)
base.GenerateLimitOffset(selectExpression);
}

protected override void VisitProjection(IReadOnlyList<Expression> projections)
{
var comparisonTransformer = new ProjectionComparisonTransformingVisitor();
var transformedProjections = projections.Select(comparisonTransformer.Visit).ToList();

base.VisitProjection(transformedProjections);
}

public virtual Expression VisitRowNumber(RowNumberExpression rowNumberExpression)
{
Check.NotNull(rowNumberExpression, nameof(rowNumberExpression));
Expand All @@ -89,5 +99,59 @@ public override Expression VisitSqlFunction(SqlFunctionExpression sqlFunctionExp
}
return base.VisitSqlFunction(sqlFunctionExpression);
}

private class ProjectionComparisonTransformingVisitor : RelinqExpressionVisitor
{
protected override Expression VisitUnary(UnaryExpression node)
{
if (node.NodeType == ExpressionType.Not
&& node.Operand is AliasExpression)
{
return Expression.Condition(
node,
Expression.Constant(true, typeof(bool)),
Expression.Constant(false, typeof(bool)));
}

return base.VisitUnary(node);
}

protected override Expression VisitBinary(BinaryExpression node)
{
if (node.IsComparisonOperation())
{
return Expression.Condition(
node,
Expression.Constant(true, typeof(bool)),
Expression.Constant(false, typeof(bool)));
}

return base.VisitBinary(node);
}


protected override Expression VisitConditional(ConditionalExpression node)
{
var test = Visit(node.Test);
if (test is AliasExpression)
{
return Expression.Condition(
Expression.Equal(test, Expression.Constant(true, typeof(bool))),
Visit(node.IfTrue),
Visit(node.IfFalse));
}

var condition = test as ConditionalExpression;
if (condition != null)
{
return Expression.Condition(
condition.Test,
Visit(node.IfTrue),
Visit(node.IfFalse));
}
return base.VisitConditional(node);
}
}

}
}
12 changes: 12 additions & 0 deletions src/EntityFramework.Relational/Extensions/ExpressionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@ public static bool IsLogicalOperation([NotNull] this Expression expression)
|| (expression.NodeType == ExpressionType.OrElse);
}

public static bool IsComparisonOperation([NotNull] this Expression expression)
{
Check.NotNull(expression, nameof(expression));

return expression.NodeType == ExpressionType.Equal
|| expression.NodeType == ExpressionType.NotEqual
|| expression.NodeType == ExpressionType.LessThan
|| expression.NodeType == ExpressionType.LessThanOrEqual
|| expression.NodeType == ExpressionType.GreaterThan
|| expression.NodeType == ExpressionType.GreaterThanOrEqual;
}

public static ColumnExpression TryGetColumnExpression([NotNull] this Expression expression)
=> (expression as AliasExpression)?.TryGetColumnExpression();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ public virtual Expression VisitSelect(SelectExpression selectExpression)
_relationalCommandBuilder.Append(", ");
}

VisitJoin(selectExpression.Projection);
VisitProjection(selectExpression.Projection);

projectionAdded = true;
}
Expand Down Expand Up @@ -227,6 +227,13 @@ var predicate
return selectExpression;
}

protected virtual void VisitProjection([NotNull] IReadOnlyList<Expression> projections)
{
var nullComparisonTransformer = new NullComparisonTransformingVisitor(_parametersValues);

VisitJoin(projections.Select(e => nullComparisonTransformer.Visit(e)).ToList());
}

protected virtual void GenerateOrderBy([NotNull] IReadOnlyList<Ordering> orderings)
{
_relationalCommandBuilder.Append("ORDER BY ");
Expand Down Expand Up @@ -1169,4 +1176,4 @@ var columnExpression
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,13 @@ await AssertQuery<Customer>(
cs => cs.Select(c => new { c.City, c.Phone, c.Country }));
}

[ConditionalFact]
public virtual async Task Select_anonymous_conditional_expression()
{
await AssertQuery<Product>(
ps => ps.Select(p => new { p.ProductID, IsAvailable = p.UnitsInStock > 0 }));
}

[ConditionalFact]
public virtual async Task Select_customer_table()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,72 @@ public virtual void Where_nullable_enum_with_nullable_parameter()
}
}

[ConditionalFact]
public virtual void Select_inverted_boolean()
{
using (var context = CreateContext())
{
var automaticWeapons = context.Weapons
.Where(w => w.IsAutomatic)
.Select(w => new { w.Id, Manual = !w.IsAutomatic })
.ToList();

Assert.True(automaticWeapons.All(t => t.Manual == false));
}
}

[ConditionalFact]
public virtual void Select_comparison_with_null()
{
AmmunitionType? ammunitionType = AmmunitionType.Cartridge;
using (var context = CreateContext())
{
var cartidgeWeapons = context.Weapons
.Where(w => w.AmmunitionType == ammunitionType)
.Select(w => new { w.Id, Cartidge = w.AmmunitionType == ammunitionType })
.ToList();

Assert.True(cartidgeWeapons.All(t => t.Cartidge == true));
}

ammunitionType = null;
using (var context = CreateContext())
{
var cartidgeWeapons = context.Weapons
.Where(w => w.AmmunitionType == ammunitionType)
.Select(w => new { w.Id, Cartidge = w.AmmunitionType == ammunitionType })
.ToList();

Assert.True(cartidgeWeapons.All(t => t.Cartidge == true));
}
}

[ConditionalFact]
public virtual void Select_ternary_operation_with_boolean()
{
using (var context = CreateContext())
{
var weapons = context.Weapons
.Select(w => new { w.Id, Num = w.IsAutomatic ? 1 : 0})
.ToList();

Assert.Equal(3, weapons.Count(w => w.Num == 1));
}
}

[ConditionalFact]
public virtual void Select_ternary_operation_with_inverted_boolean()
{
using (var context = CreateContext())
{
var weapons = context.Weapons
.Select(w => new { w.Id, Num = !w.IsAutomatic ? 1 : 0 })
.ToList();

Assert.Equal(7, weapons.Count(w => w.Num == 1));
}
}

[ConditionalFact]
public virtual void Select_Where_Navigation()
{
Expand Down
17 changes: 12 additions & 5 deletions test/EntityFramework.Core.FunctionalTests/QueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,7 @@ public virtual void Where_select_many_and()
AssertQuery<Customer, Employee>((cs, es) =>
from c in cs
from e in es
// ReSharper disable ArrangeRedundantParentheses
// ReSharper disable ArrangeRedundantParentheses
where (c.City == "London" && c.Country == "UK")
&& (e.City == "London" && e.Country == "UK")
select new { c, e });
Expand Down Expand Up @@ -1597,6 +1597,13 @@ public virtual void Select_anonymous_bool_constant_in_expression()
cs => cs.Select(c => new { c.CustomerID, Expression = c.CustomerID.Length + 5 }));
}

[ConditionalFact]
public virtual void Select_anonymous_conditional_expression()
{
AssertQuery<Product>(
ps => ps.Select(p => new { p.ProductID, IsAvailable = p.UnitsInStock > 0 }));
}

[ConditionalFact]
public virtual void Select_customer_table()
{
Expand Down Expand Up @@ -4474,7 +4481,7 @@ public virtual void Select_take_skip_null_coalesce_operator2()
public virtual void Select_take_skip_null_coalesce_operator3()
{
AssertQuery<Customer>(
cs => cs.OrderBy(c => c.Region ?? "ZZ").Take(10).Skip(5),
cs => cs.OrderBy(c => c.Region ?? "ZZ").Take(10).Skip(5),
entryCount: 5);
}

Expand Down Expand Up @@ -4561,9 +4568,9 @@ public virtual void Select_Subquery_Single()
{
var orderDetails
= (from od in context.Set<OrderDetail>()
select (from o in context.Set<Order>()
where od.OrderID == o.OrderID
select o).First())
select (from o in context.Set<Order>()
where od.OrderID == o.OrderID
select o).First())
.Take(2)
.ToList();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,61 +66,71 @@ public static void Seed(GearsOfWarContext context)
var marcusLancer = new Weapon
{
Name = "Marcus' Lancer",
AmmunitionType = AmmunitionType.Cartridge
AmmunitionType = AmmunitionType.Cartridge,
IsAutomatic = true
};

var marcusGnasher = new Weapon
{
Name = "Marcus' Gnasher",
AmmunitionType = AmmunitionType.Shell,
IsAutomatic = false,
SynergyWith = marcusLancer
};

var domsHammerburst = new Weapon
{
Name = "Dom's Hammerburst",
AmmunitionType = AmmunitionType.Cartridge
AmmunitionType = AmmunitionType.Cartridge,
IsAutomatic = false
};

var domsGnasher = new Weapon
{
Name = "Dom's Gnasher",
AmmunitionType = AmmunitionType.Shell
AmmunitionType = AmmunitionType.Shell,
IsAutomatic = false
};

var colesGnasher = new Weapon
{
Name = "Cole's Gnasher",
AmmunitionType = AmmunitionType.Shell
AmmunitionType = AmmunitionType.Shell,
IsAutomatic = false
};

var colesMulcher = new Weapon
{
Name = "Cole's Mulcher",
AmmunitionType = AmmunitionType.Cartridge
AmmunitionType = AmmunitionType.Cartridge,
IsAutomatic = true
};

var bairdsLancer = new Weapon
{
Name = "Baird's Lancer",
AmmunitionType = AmmunitionType.Cartridge
AmmunitionType = AmmunitionType.Cartridge,
IsAutomatic = true
};

var bairdsGnasher = new Weapon
{
Name = "Baird's Gnasher",
AmmunitionType = AmmunitionType.Shell
AmmunitionType = AmmunitionType.Shell,
IsAutomatic = false
};

var paduksMarkza = new Weapon
{
Name = "Paduk's Markza",
AmmunitionType = AmmunitionType.Cartridge
AmmunitionType = AmmunitionType.Cartridge,
IsAutomatic = false
};

var maulersFlail = new Weapon
{
Name = "Mauler's Flail"
Name = "Mauler's Flail",
IsAutomatic = false
};

context.Weapons.Add(marcusLancer);
Expand Down Expand Up @@ -228,7 +238,7 @@ public static void Seed(GearsOfWarContext context)
};

var marcus = new Officer
{
{
Nickname = "Marcus",
FullName = "Marcus Fenix",
SquadId = deltaSquad.Id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ public class Weapon
public int Id { get; set; }
public string Name { get; set; }
public AmmunitionType? AmmunitionType { get; set; }
public bool IsAutomatic { get; set; }

// 1 - 1 self reference
public int? SynergyWithId { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public override void From_sql_queryable_simple_columns_out_of_order()
base.From_sql_queryable_simple_columns_out_of_order();

Assert.Equal(
@"SELECT ""Id"", ""Name"", ""AmmunitionType"", ""OwnerFullName"", ""SynergyWithId"" FROM ""Weapon"" ORDER BY ""Name""",
@"SELECT ""Id"", ""Name"", ""IsAutomatic"", ""AmmunitionType"", ""OwnerFullName"", ""SynergyWithId"" FROM ""Weapon"" ORDER BY ""Name""",
Sql);
}

Expand Down
Loading

0 comments on commit 2837445

Please sign in to comment.