Skip to content

Commit

Permalink
Fix to #34056 - AOT/Query: for queries with JSON, interceptors genera…
Browse files Browse the repository at this point in the history
…te code with labels that are not uniquified (#34323)

Problem was that when we generate shaper for extracting json values from the reader (streaming) we use hard-coded name for a label that's used in the loop going through json properties. This is fine for normal scenarios, since compiler can figure out the right label to jump to based on scope/reference, but when we generate c# code from that expression, we run into issues for case of nested JSON.
Existing code already works if the label is not named explicitly - LinqToCSharpSyntaxTranslator generates names for unnamed labels and uniquifies the names when needed. Fix is to add the uniquification to all labels, not only the unnamed one.

Also fixed another small bug in LinqToCSharpSyntaxTranslator encountered in the testing - when we process LambdaExpression and are dealing with lifted statements (that need to be incorporated into the lambda body), we assumed it would always be a single expression, but it could also be a block. Fix is to add handling for the block case, which is simply pre-pending lifted statements to the ones already in the block.

Added JSON entites to Blog so that we can test the scenario, also added some posts (there were none) for cases which don't work with JSON entities, e.g. set ops.

Fixes #34056
  • Loading branch information
maumar authored Jul 31, 2024
1 parent b25cba7 commit ccf0247
Show file tree
Hide file tree
Showing 4 changed files with 351 additions and 133 deletions.
39 changes: 23 additions & 16 deletions src/EFCore.Design/Query/Internal/LinqToCSharpSyntaxTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ private sealed record StackFrame(
Dictionary<ParameterExpression, string> Variables,
HashSet<string> VariableNames,
Dictionary<LabelTarget, string> Labels,
HashSet<string> UnnamedLabelNames);
HashSet<string> UniqueLabelNames);

private readonly Stack<StackFrame> _stack
= new([new StackFrame([], [], [], [])]);
Expand Down Expand Up @@ -160,7 +160,7 @@ protected virtual SyntaxNode TranslateCore(
Check.DebugAssert(_stack.Peek().Variables.Count == 0, "_stack.Peek().Parameters.Count == 0");
Check.DebugAssert(_stack.Peek().VariableNames.Count == 0, "_stack.Peek().ParameterNames.Count == 0");
Check.DebugAssert(_stack.Peek().Labels.Count == 0, "_stack.Peek().Labels.Count == 0");
Check.DebugAssert(_stack.Peek().UnnamedLabelNames.Count == 0, "_stack.Peek().UnnamedLabelNames.Count == 0");
Check.DebugAssert(_stack.Peek().UniqueLabelNames.Count == 0, "_stack.Peek().UniqueLabelNames.Count == 0");

foreach (var unsafeAccessor in _fieldUnsafeAccessors.Values.Concat(_methodUnsafeAccessors.Values))
{
Expand Down Expand Up @@ -714,6 +714,8 @@ static bool IsExpressionValidAsStatement(ExpressionSyntax expression)
void PreprocessLabels()
{
// LINQ label targets can be unnamed, so we need to generate names for unnamed ones and maintain a target->name mapping.
// Also labels can have duplicated names - we need to de-duplicate them before we can generate a valid c# code
// just like we do with variables/parameters
// We need to maintain this as a stack for every block which has labels.
// Normal blocks get their own labels stack frame, which gets popped when we leave the block. Expression labels add their
// labels to their parent's stack frame (since they get lifted).
Expand All @@ -726,21 +728,17 @@ void PreprocessLabels()
continue;
}

var (_, _, labels, unnamedLabelNames) = stackFrame;
var (_, _, labels, uniqueLabelNames) = stackFrame;

// Generate names for unnamed label targets and uniquify
// Generate names for unnamed label targets and uniquify (all label names)
identifier = label.Target.Name ?? "unnamedLabel";
var identifierBase = identifier;
for (var i = 0; unnamedLabelNames.Contains(identifier); i++)
for (var i = 0; uniqueLabelNames.Contains(identifier); i++)
{
identifier = identifierBase + i;
}

if (label.Target.Name is null)
{
unnamedLabelNames.Add(identifier);
}

uniqueLabelNames.Add(identifier);
labels.Add(label.Target, identifier);
}
}
Expand Down Expand Up @@ -1507,15 +1505,24 @@ protected override Expression VisitLambda<T>(Expression<T> lambda)
var expressionBody = body as ExpressionSyntax;
var blockBody = body as BlockSyntax;

// If the lambda body was an expression that had lifted statements (e.g. some block in expression context), we need to create
// a block to contain these statements
if (_liftedState.Statements.Count > 0)
{
Check.DebugAssert(lambda.ReturnType != typeof(void), "lambda.ReturnType != typeof(void)");
Check.DebugAssert(expressionBody != null, "expressionBody != null");

blockBody = Block(_liftedState.Statements.Append(ReturnStatement(expressionBody)));
expressionBody = null;
if (expressionBody != null)
{
// If the lambda body was an expression that had lifted statements (e.g. some block in expression context), we need to create
// a block to contain these statements
blockBody = Block(_liftedState.Statements.Append(ReturnStatement(expressionBody)));
expressionBody = null;
}
else
{
// If the lambda body was already a block, we just prepend lifted statements to the ones already existing in the block
Check.DebugAssert(blockBody != null, "expressionBody != null || blockBody != null");
blockBody = Block(_liftedState.Statements.Concat(blockBody.Statements));
}

_liftedState.Statements.Clear();
}

Expand Down Expand Up @@ -2734,7 +2741,7 @@ private StackFrame PushNewStackFrame()
new Dictionary<ParameterExpression, string>(previousFrame.Variables),
[..previousFrame.VariableNames],
new Dictionary<LabelTarget, string>(previousFrame.Labels),
[..previousFrame.UnnamedLabelNames]);
[..previousFrame.UniqueLabelNames]);

_stack.Push(newFrame);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
using Microsoft.EntityFrameworkCore.Query.Internal;
using static Microsoft.EntityFrameworkCore.TestUtilities.PrecompiledQueryTestHelpers;
using Blog = Microsoft.EntityFrameworkCore.Query.PrecompiledQueryRelationalTestBase.Blog;
using Post = Microsoft.EntityFrameworkCore.Query.PrecompiledQueryRelationalTestBase.Post;
using JsonRoot = Microsoft.EntityFrameworkCore.Query.PrecompiledQueryRelationalTestBase.JsonRoot;
using JsonBranch = Microsoft.EntityFrameworkCore.Query.PrecompiledQueryRelationalTestBase.JsonBranch;
namespace Microsoft.EntityFrameworkCore.Query;

public abstract class PrecompiledQueryRelationalFixture
Expand All @@ -27,9 +30,26 @@ protected override IServiceCollection AddServices(IServiceCollection serviceColl

protected override async Task SeedAsync(PrecompiledQueryRelationalTestBase.PrecompiledQueryContext context)
{
context.Blogs.AddRange(
new Blog { Id = 8, Name = "Blog1" },
new Blog { Id = 9, Name = "Blog2" });
var blog1 = new Blog { Id = 8, Name = "Blog1", Json = [] };
var blog2 = new Blog
{
Id = 9,
Name = "Blog2",
Json =
[
new JsonRoot { Number = 1, Text = "One", Inner = new JsonBranch { Date = new DateTime(2001, 1, 1) } },
new JsonRoot { Number = 2, Text = "Two", Inner = new JsonBranch { Date = new DateTime(2002, 2, 2) } },
]};

context.Blogs.AddRange(blog1, blog2);

var post11 = new Post { Id = 11, Title = "Post11", Blog = blog1 };
var post12 = new Post { Id = 12, Title = "Post12", Blog = blog1 };
var post21 = new Post { Id = 21, Title = "Post21", Blog = blog2 };
var post22 = new Post { Id = 22, Title = "Post22", Blog = blog2 };
var post23 = new Post { Id = 23, Title = "Post23", Blog = blog2 };

context.Posts.AddRange(post11, post12, post21, post22, post23);
await context.SaveChangesAsync();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,27 @@ public virtual Task BinaryExpression()
=> Test("""
var id = 3;
var blogs = await context.Blogs.Where(b => b.Id > id).ToListAsync();

Assert.Equal(2, blogs.Count);
var orderedBlogs = blogs.OrderBy(x => x.Id).ToList();
var blog1 = orderedBlogs[0];
var blog2 = orderedBlogs[1];

Assert.Equal(8, blog1.Id);
Assert.Equal("Blog1", blog1.Name);
Assert.Empty(blog1.Json);

Assert.Equal(9, blog2.Id);
Assert.Equal("Blog2", blog2.Name);
Assert.Equal(2, blog2.Json.Count);

Assert.Equal(1, blog2.Json[0].Number);
Assert.Equal("One", blog2.Json[0].Text);
Assert.Equal(new DateTime(2001, 1, 1), blog2.Json[0].Inner.Date);

Assert.Equal(2, blog2.Json[1].Number);
Assert.Equal("Two", blog2.Json[1].Text);
Assert.Equal(new DateTime(2002, 2, 2), blog2.Json[1].Inner.Date);
""");

[ConditionalFact]
Expand Down Expand Up @@ -729,6 +750,23 @@ public virtual Task Terminating_ExecuteUpdateAsync()
public virtual Task Union()
=> Test(
"""
var posts = await context.Posts.Where(p => p.Id > 11)
.Union(context.Posts.Where(p => p.Id < 21))
.OrderBy(p => p.Id)
.ToListAsync();

Assert.Collection(posts,
b => Assert.Equal(11, b.Id),
b => Assert.Equal(12, b.Id),
b => Assert.Equal(21, b.Id),
b => Assert.Equal(22, b.Id),
b => Assert.Equal(23, b.Id));
""");

[ConditionalFact(Skip = "issue 33378")]
public virtual Task UnionOnEntitiesWithJson()
=> Test(
"""
var blogs = await context.Blogs.Where(b => b.Id > 7)
.Union(context.Blogs.Where(b => b.Id < 10))
.OrderBy(b => b.Id)
Expand All @@ -743,6 +781,24 @@ public virtual Task Union()
public virtual Task Concat()
=> Test(
"""
var posts = await context.Posts.Where(p => p.Id > 11)
.Concat(context.Posts.Where(p => p.Id < 21))
.OrderBy(p => p.Id)
.ToListAsync();

Assert.Collection(posts,
b => Assert.Equal(11, b.Id),
b => Assert.Equal(12, b.Id),
b => Assert.Equal(12, b.Id),
b => Assert.Equal(21, b.Id),
b => Assert.Equal(22, b.Id),
b => Assert.Equal(23, b.Id));
""");

[ConditionalFact(Skip = "issue 33378")]
public virtual Task ConcatOnEntitiesWithJson()
=> Test(
"""
var blogs = await context.Blogs.Where(b => b.Id > 7)
.Concat(context.Blogs.Where(b => b.Id < 10))
.OrderBy(b => b.Id)
Expand All @@ -759,6 +815,20 @@ public virtual Task Concat()
public virtual Task Intersect()
=> Test(
"""
var posts = await context.Posts.Where(b => b.Id > 11)
.Intersect(context.Posts.Where(b => b.Id < 22))
.OrderBy(b => b.Id)
.ToListAsync();

Assert.Collection(posts,
b => Assert.Equal(12, b.Id),
b => Assert.Equal(21, b.Id));
""");

[ConditionalFact(Skip = "issue 33378")]
public virtual Task IntersectOnEntitiesWithJson()
=> Test(
"""
var blogs = await context.Blogs.Where(b => b.Id > 7)
.Intersect(context.Blogs.Where(b => b.Id > 8))
.OrderBy(b => b.Id)
Expand All @@ -771,6 +841,20 @@ public virtual Task Intersect()
public virtual Task Except()
=> Test(
"""
var posts = await context.Posts.Where(b => b.Id > 11)
.Except(context.Posts.Where(b => b.Id > 21))
.OrderBy(b => b.Id)
.ToListAsync();

Assert.Collection(posts,
b => Assert.Equal(12, b.Id),
b => Assert.Equal(21, b.Id));
""");

[ConditionalFact(Skip = "issue 33378")]
public virtual Task ExceptOnEntitiesWithJson()
=> Test(
"""
var blogs = await context.Blogs.Where(b => b.Id > 7)
.Except(context.Blogs.Where(b => b.Id > 8))
.OrderBy(b => b.Id)
Expand Down Expand Up @@ -1066,6 +1150,20 @@ public class PrecompiledQueryContext(DbContextOptions options) : DbContext(optio
{
public DbSet<Blog> Blogs { get; set; } = null!;
public DbSet<Post> Posts { get; set; } = null!;

protected override void OnModelCreating(ModelBuilder modelBuilder)
{
base.OnModelCreating(modelBuilder);
modelBuilder.Entity<Blog>().OwnsMany(
x => x.Json,
n =>
{
n.ToJson();
n.OwnsOne(xx => xx.Inner);
});
modelBuilder.Entity<Blog>().HasMany(x => x.Posts).WithOne(x => x.Blog).OnDelete(DeleteBehavior.Cascade);
modelBuilder.Entity<Post>().Property(x => x.Id).ValueGeneratedNever();
}
}

protected PrecompiledQueryRelationalFixture Fixture { get; }
Expand Down Expand Up @@ -1128,8 +1226,21 @@ public Blog(int id, string name)
[DatabaseGenerated(DatabaseGeneratedOption.None)]
public int Id { get; set; }
public string? Name { get; set; }

public List<Post> Posts { get; set; } = new();
public List<JsonRoot> Json { get; set; } = new();
}

public class JsonRoot
{
public int Number { get; set; }
public string? Text { get; set; }

public JsonBranch Inner { get; set; } = null!;
}

public class JsonBranch
{
public DateTime Date { get; set; }
}

public class Post
Expand Down
Loading

0 comments on commit ccf0247

Please sign in to comment.