Skip to content

Commit

Permalink
Improve domain event publisher implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
incompetent-developer committed Nov 20, 2023
1 parent b090ff8 commit 1eb7782
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace Codehard.Infrastructure.EntityFramework.Tests;
public class DomainEntityTests
{
[Fact]
public async Task WhenSaveChanges_UsingDomainEventDbContext_ShouldPublishAndClearEvents()
public async Task WhenSaveChangesAsync_UsingDomainEventDbContext_ShouldPublishAndClearEvents()
{
// Arrange
var assembly = Assembly.GetExecutingAssembly();
Expand Down Expand Up @@ -59,7 +59,7 @@ public async Task WhenSaveChanges_UsingDomainEventDbContext_ShouldPublishAndClea
}

[Fact]
public async Task WhenSaveChanges_UsingGlobalPublisherFunction_ShouldPublishAndClearEvents()
public async Task WhenSaveChangesAsync_UsingGlobalPublisherFunction_ShouldPublishAndClearEvents()
{
// Arrange
var assembly = Assembly.GetExecutingAssembly();
Expand Down Expand Up @@ -113,6 +113,109 @@ public async Task WhenSaveChanges_UsingGlobalPublisherFunction_ShouldPublishAndC
Assert.Empty(entity.Events);
}

[Fact]
public void WhenSaveChanges_UsingDomainEventDbContext_ShouldPublishAndClearEvents()
{
// Arrange
var assembly = Assembly.GetExecutingAssembly();
var options = new DbContextOptionsBuilder<TestDbContext>()
.UseSqlite(CreateInMemoryDatabase())
.Options;

var loggerMock = new Mock<ILogger<TestDbContext>>();
var logger = loggerMock.Object;
using var context = new TestDbContext(
options,
builder => builder.ApplyConfigurationsFromAssemblyFor<TestDbContext>(assembly),
logger);
context.Database.EnsureCreated();

// Act
var entity = EntityA.Create();
entity.UpdateValue("New Value");

context.As.Add(entity);
context.SaveChanges();

// Assert
loggerMock.Verify(
logger =>
logger.Log(
It.Is<LogLevel>(logLevel => logLevel == LogLevel.Information),
It.Is<EventId>(eventId => true),
It.Is<It.IsAnyType>((@object, @type) =>
@object.ToString()!.Contains(nameof(EntityCreatedEvent))),
It.IsAny<Exception>(),
It.IsAny<Func<It.IsAnyType, Exception?, string>>()),
Times.Once);
loggerMock.Verify(
logger =>
logger.Log(
It.Is<LogLevel>(logLevel => logLevel == LogLevel.Information),
It.Is<EventId>(eventId => true),
It.Is<It.IsAnyType>((@object, @type) =>
@object.ToString()!.Contains(nameof(ValueChangedEvent))),
It.IsAny<Exception>(),
It.IsAny<Func<It.IsAnyType, Exception?, string>>()),
Times.Once);
Assert.Empty(entity.Events);
}

[Fact]
public void WhenSaveChanges_UsingGlobalPublisherFunction_ShouldPublishAndClearEvents()
{
// Arrange
var assembly = Assembly.GetExecutingAssembly();
var options = new DbContextOptionsBuilder<TestDbContext>()
.UseSqlite(CreateInMemoryDatabase())
.Options;

var loggerMock = new Mock<ILogger<TestDbContext>>();
var logger = loggerMock.Object;
using var context = new TestDbContext(
options,
builder => builder.ApplyConfigurationsFromAssemblyFor<TestDbContext>(assembly),
dm =>
{
// We use LogWarning here to distinct
// between the global and local (within the DbContext) publisher
logger.LogWarning(dm.ToString());

return Task.CompletedTask;
});
context.Database.EnsureCreated();

// Act
var entity = EntityA.Create();
entity.UpdateValue("New Value");

context.As.Add(entity);
context.SaveChanges();

// Assert
loggerMock.Verify(
logger =>
logger.Log(
It.Is<LogLevel>(logLevel => logLevel == LogLevel.Warning),
It.Is<EventId>(eventId => true),
It.Is<It.IsAnyType>((@object, @type) =>
@object.ToString()!.Contains(nameof(EntityCreatedEvent))),
It.IsAny<Exception>(),
It.IsAny<Func<It.IsAnyType, Exception?, string>>()),
Times.Once);
loggerMock.Verify(
logger =>
logger.Log(
It.Is<LogLevel>(logLevel => logLevel == LogLevel.Warning),
It.Is<EventId>(eventId => true),
It.Is<It.IsAnyType>((@object, @type) =>
@object.ToString()!.Contains(nameof(ValueChangedEvent))),
It.IsAny<Exception>(),
It.IsAny<Func<It.IsAnyType, Exception?, string>>()),
Times.Once);
Assert.Empty(entity.Events);
}

private static SqliteConnection CreateInMemoryDatabase()
{
var connection = new SqliteConnection("DataSource=:memory:");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public class EntityCreatedEvent : IDomainEvent<EntityAKey>
public record ValueChangedEvent(EntityAKey Id, string NewValue)
: IDomainEvent<EntityAKey>;

public class EntityAKey : IEntityKey
public struct EntityAKey
{
public Guid Value { get; set; }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
<TargetFramework>net6.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<Version>3.0.0-preview-2</Version>
<Version>3.0.0-preview-3</Version>
<Description>A library contains common code related to Entity Framework Core.</Description>
<PackageProjectUrl>https://github.com/codehardth/Codehard.Common</PackageProjectUrl>
<RepositoryUrl>https://github.com/codehardth/Codehard.Common</RepositoryUrl>
<PackageReleaseNotes>Added a foundation for domain event.</PackageReleaseNotes>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
</PropertyGroup>

<ItemGroup>
Expand All @@ -23,11 +24,11 @@
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\Codehard.Common\Codehard.Common.DomainModel\Codehard.Common.DomainModel.csproj" />
<InternalsVisibleTo Include="Codehard.Infrastructure.EntityFramework.Tests"/>
</ItemGroup>

<ItemGroup>
<InternalsVisibleTo Include="Codehard.Infrastructure.EntityFramework.Tests"/>
<ProjectReference Include="..\..\Codehard.Common\Codehard.Common.DomainModel\Codehard.Common.DomainModel.csproj" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
using System.Collections.Concurrent;
using System.Reflection;
using Codehard.Common.DomainModel;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Diagnostics;

namespace Codehard.Infrastructure.EntityFramework.Interceptors;
Expand Down Expand Up @@ -30,6 +32,36 @@ public DomainEventPublisherInterceptor(PublishDomainEventDelegate? publisher = d
this.publisher = publisher;
}

public override int SavedChanges(SaveChangesCompletedEventData eventData, int result)
{
var context = eventData.Context;

if (context is null)
{
return result;
}

var publisherFunc = GetPublisherFunction(context);

if (publisherFunc is null)
{
return result;
}

var domainEntities = RetrieveDomainEntities(context);
var events = GetEvents(domainEntities);

foreach (var @event in events)
{
publisherFunc(@event).Wait();
}

ClearExistingEvents(domainEntities);

return result;
}


public override async ValueTask<InterceptionResult<int>> SavingChangesAsync(
DbContextEventData eventData,
InterceptionResult<int> result,
Expand All @@ -39,66 +71,84 @@ public override async ValueTask<InterceptionResult<int>> SavingChangesAsync(

if (context is null)
{
return new InterceptionResult<int>();
return result;
}

var publisherFunc = GetPublisherFunction(context);

if (publisherFunc is null)
{
return result;
}

var domainEntities = RetrieveDomainEntities(context);
var events = GetEvents(domainEntities);

foreach (var @event in events)
{
await publisherFunc(@event);
}

ClearExistingEvents(domainEntities);

return result;
}

private PublishDomainEventDelegate? GetPublisherFunction(DbContext context)
{
var publisherFunc =
this.publisher ??
(context is IDomainEventDbContext domainEventDbContext
? domainEventDbContext.PublishDomainEventAsync
: null);
return publisherFunc;
}

if (publisherFunc is null)
{
return new InterceptionResult<int>();
}

private static EntityEntry[] RetrieveDomainEntities(DbContext context)
{
// Retrieve domain entities from the change tracker.
var domainEntities =
context.ChangeTracker
.Entries()
.Where(e => e.Entity.GetType().IsAssignableTo(typeof(IEntity)))
.ToArray();
return domainEntities;
}

private static IEnumerable<IDomainEvent> GetEvents(EntityEntry[] domainEntities)
{
// Extract and publish domain events.
var events = domainEntities.SelectMany(e => ((IEntity)e.Entity).Events);
return events;
}

foreach (var @event in events)
{
await publisherFunc(@event);
}

private static void ClearExistingEvents(EntityEntry[] domainEntities)
{
// Clear all existing events
foreach (var entityEntry in domainEntities)
{
var type = entityEntry.Entity.GetType();
var eventsFieldInfo = TryGetEventsFieldInfo(type);

if (eventsFieldInfo?.GetValue(entityEntry.Entity) is IList list)
if (eventsFieldInfo.GetValue(entityEntry.Entity) is IList list)
{
list.Clear();
}
}

return new InterceptionResult<int>();

static FieldInfo? TryGetEventsFieldInfo(Type type)
static FieldInfo TryGetEventsFieldInfo(Type type)
{
if (FieldInfoCache.TryGetValue(type, out var fi))
{
return fi;
}

var baseType = TraverseBaseEntityType(type);

if (baseType is null)
{
return default;
}
var fields = TraverseBackingFields(type);

var fieldInfo =
baseType.GetField("events", BindingFlags.Instance | BindingFlags.NonPublic)
?? throw new FieldAccessException();
fields.FirstOrDefault(f =>
f.Name is "events" or "_events" && f.FieldType.IsAssignableTo(typeof(IList)))
?? throw new FieldAccessException("Unable to find backing field 'events' or '_events'.");

FieldInfoCache.AddOrUpdate(
type,
Expand All @@ -108,23 +158,20 @@ public override async ValueTask<InterceptionResult<int>> SavingChangesAsync(
return fieldInfo;
}

static Type? TraverseBaseEntityType(Type? type)
static IEnumerable<FieldInfo> TraverseBackingFields(Type type)
{
while (true)
{
var baseType = type?.BaseType;
var currentType = type;

if (baseType is null)
{
return null;
}
while (currentType is not null)
{
var fields = currentType.GetFields(BindingFlags.Instance | BindingFlags.NonPublic);

if (baseType.IsGenericType && baseType.GetGenericTypeDefinition() == typeof(Entity<>))
foreach (var field in fields)
{
return baseType;
yield return field;
}

type = baseType;
currentType = currentType.BaseType;
}
}
}
Expand Down

0 comments on commit 1eb7782

Please sign in to comment.