diff --git a/src/Codehard.Infrastructure/Codehard.Infrastructure.EntityFramework.Tests/DomainEntityTests.cs b/src/Codehard.Infrastructure/Codehard.Infrastructure.EntityFramework.Tests/DomainEntityTests.cs index d6af488..0c4db46 100644 --- a/src/Codehard.Infrastructure/Codehard.Infrastructure.EntityFramework.Tests/DomainEntityTests.cs +++ b/src/Codehard.Infrastructure/Codehard.Infrastructure.EntityFramework.Tests/DomainEntityTests.cs @@ -11,7 +11,7 @@ namespace Codehard.Infrastructure.EntityFramework.Tests; public class DomainEntityTests { [Fact] - public async Task WhenSaveChanges_UsingDomainEventDbContext_ShouldPublishAndClearEvents() + public async Task WhenSaveChangesAsync_UsingDomainEventDbContext_ShouldPublishAndClearEvents() { // Arrange var assembly = Assembly.GetExecutingAssembly(); @@ -59,7 +59,7 @@ public async Task WhenSaveChanges_UsingDomainEventDbContext_ShouldPublishAndClea } [Fact] - public async Task WhenSaveChanges_UsingGlobalPublisherFunction_ShouldPublishAndClearEvents() + public async Task WhenSaveChangesAsync_UsingGlobalPublisherFunction_ShouldPublishAndClearEvents() { // Arrange var assembly = Assembly.GetExecutingAssembly(); @@ -113,6 +113,109 @@ public async Task WhenSaveChanges_UsingGlobalPublisherFunction_ShouldPublishAndC Assert.Empty(entity.Events); } + [Fact] + public void WhenSaveChanges_UsingDomainEventDbContext_ShouldPublishAndClearEvents() + { + // Arrange + var assembly = Assembly.GetExecutingAssembly(); + var options = new DbContextOptionsBuilder() + .UseSqlite(CreateInMemoryDatabase()) + .Options; + + var loggerMock = new Mock>(); + var logger = loggerMock.Object; + using var context = new TestDbContext( + options, + builder => builder.ApplyConfigurationsFromAssemblyFor(assembly), + logger); + context.Database.EnsureCreated(); + + // Act + var entity = EntityA.Create(); + entity.UpdateValue("New Value"); + + context.As.Add(entity); + context.SaveChanges(); + + // Assert + loggerMock.Verify( + logger => + logger.Log( + It.Is(logLevel => logLevel == LogLevel.Information), + It.Is(eventId => true), + It.Is((@object, @type) => + @object.ToString()!.Contains(nameof(EntityCreatedEvent))), + It.IsAny(), + It.IsAny>()), + Times.Once); + loggerMock.Verify( + logger => + logger.Log( + It.Is(logLevel => logLevel == LogLevel.Information), + It.Is(eventId => true), + It.Is((@object, @type) => + @object.ToString()!.Contains(nameof(ValueChangedEvent))), + It.IsAny(), + It.IsAny>()), + Times.Once); + Assert.Empty(entity.Events); + } + + [Fact] + public void WhenSaveChanges_UsingGlobalPublisherFunction_ShouldPublishAndClearEvents() + { + // Arrange + var assembly = Assembly.GetExecutingAssembly(); + var options = new DbContextOptionsBuilder() + .UseSqlite(CreateInMemoryDatabase()) + .Options; + + var loggerMock = new Mock>(); + var logger = loggerMock.Object; + using var context = new TestDbContext( + options, + builder => builder.ApplyConfigurationsFromAssemblyFor(assembly), + dm => + { + // We use LogWarning here to distinct + // between the global and local (within the DbContext) publisher + logger.LogWarning(dm.ToString()); + + return Task.CompletedTask; + }); + context.Database.EnsureCreated(); + + // Act + var entity = EntityA.Create(); + entity.UpdateValue("New Value"); + + context.As.Add(entity); + context.SaveChanges(); + + // Assert + loggerMock.Verify( + logger => + logger.Log( + It.Is(logLevel => logLevel == LogLevel.Warning), + It.Is(eventId => true), + It.Is((@object, @type) => + @object.ToString()!.Contains(nameof(EntityCreatedEvent))), + It.IsAny(), + It.IsAny>()), + Times.Once); + loggerMock.Verify( + logger => + logger.Log( + It.Is(logLevel => logLevel == LogLevel.Warning), + It.Is(eventId => true), + It.Is((@object, @type) => + @object.ToString()!.Contains(nameof(ValueChangedEvent))), + It.IsAny(), + It.IsAny>()), + Times.Once); + Assert.Empty(entity.Events); + } + private static SqliteConnection CreateInMemoryDatabase() { var connection = new SqliteConnection("DataSource=:memory:"); diff --git a/src/Codehard.Infrastructure/Codehard.Infrastructure.EntityFramework.Tests/Entities/EntityA.cs b/src/Codehard.Infrastructure/Codehard.Infrastructure.EntityFramework.Tests/Entities/EntityA.cs index 2aba43d..f104fd0 100644 --- a/src/Codehard.Infrastructure/Codehard.Infrastructure.EntityFramework.Tests/Entities/EntityA.cs +++ b/src/Codehard.Infrastructure/Codehard.Infrastructure.EntityFramework.Tests/Entities/EntityA.cs @@ -10,7 +10,7 @@ public class EntityCreatedEvent : IDomainEvent public record ValueChangedEvent(EntityAKey Id, string NewValue) : IDomainEvent; -public class EntityAKey : IEntityKey +public struct EntityAKey { public Guid Value { get; set; } } diff --git a/src/Codehard.Infrastructure/Codehard.Infrastructure.EntityFramework/Codehard.Infrastructure.EntityFramework.csproj b/src/Codehard.Infrastructure/Codehard.Infrastructure.EntityFramework/Codehard.Infrastructure.EntityFramework.csproj index 6d743a1..8e5c815 100644 --- a/src/Codehard.Infrastructure/Codehard.Infrastructure.EntityFramework/Codehard.Infrastructure.EntityFramework.csproj +++ b/src/Codehard.Infrastructure/Codehard.Infrastructure.EntityFramework/Codehard.Infrastructure.EntityFramework.csproj @@ -4,12 +4,13 @@ net6.0 enable enable - 3.0.0-preview-2 + 3.0.0-preview-3 A library contains common code related to Entity Framework Core. https://github.com/codehardth/Codehard.Common https://github.com/codehardth/Codehard.Common Added a foundation for domain event. true + true @@ -23,11 +24,11 @@ - + - + diff --git a/src/Codehard.Infrastructure/Codehard.Infrastructure.EntityFramework/Interceptors/DomainEventPublisherInterceptor.cs b/src/Codehard.Infrastructure/Codehard.Infrastructure.EntityFramework/Interceptors/DomainEventPublisherInterceptor.cs index 4f37888..4cbd2f9 100644 --- a/src/Codehard.Infrastructure/Codehard.Infrastructure.EntityFramework/Interceptors/DomainEventPublisherInterceptor.cs +++ b/src/Codehard.Infrastructure/Codehard.Infrastructure.EntityFramework/Interceptors/DomainEventPublisherInterceptor.cs @@ -2,6 +2,8 @@ using System.Collections.Concurrent; using System.Reflection; using Codehard.Common.DomainModel; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.ChangeTracking; using Microsoft.EntityFrameworkCore.Diagnostics; namespace Codehard.Infrastructure.EntityFramework.Interceptors; @@ -30,6 +32,36 @@ public DomainEventPublisherInterceptor(PublishDomainEventDelegate? publisher = d this.publisher = publisher; } + public override int SavedChanges(SaveChangesCompletedEventData eventData, int result) + { + var context = eventData.Context; + + if (context is null) + { + return result; + } + + var publisherFunc = GetPublisherFunction(context); + + if (publisherFunc is null) + { + return result; + } + + var domainEntities = RetrieveDomainEntities(context); + var events = GetEvents(domainEntities); + + foreach (var @event in events) + { + publisherFunc(@event).Wait(); + } + + ClearExistingEvents(domainEntities); + + return result; + } + + public override async ValueTask> SavingChangesAsync( DbContextEventData eventData, InterceptionResult result, @@ -39,66 +71,84 @@ public override async ValueTask> SavingChangesAsync( if (context is null) { - return new InterceptionResult(); + return result; } + var publisherFunc = GetPublisherFunction(context); + + if (publisherFunc is null) + { + return result; + } + + var domainEntities = RetrieveDomainEntities(context); + var events = GetEvents(domainEntities); + + foreach (var @event in events) + { + await publisherFunc(@event); + } + + ClearExistingEvents(domainEntities); + + return result; + } + + private PublishDomainEventDelegate? GetPublisherFunction(DbContext context) + { var publisherFunc = this.publisher ?? (context is IDomainEventDbContext domainEventDbContext ? domainEventDbContext.PublishDomainEventAsync : null); + return publisherFunc; + } - if (publisherFunc is null) - { - return new InterceptionResult(); - } - + private static EntityEntry[] RetrieveDomainEntities(DbContext context) + { // Retrieve domain entities from the change tracker. var domainEntities = context.ChangeTracker .Entries() .Where(e => e.Entity.GetType().IsAssignableTo(typeof(IEntity))) .ToArray(); + return domainEntities; + } + private static IEnumerable GetEvents(EntityEntry[] domainEntities) + { // Extract and publish domain events. var events = domainEntities.SelectMany(e => ((IEntity)e.Entity).Events); + return events; + } - foreach (var @event in events) - { - await publisherFunc(@event); - } - + private static void ClearExistingEvents(EntityEntry[] domainEntities) + { // Clear all existing events foreach (var entityEntry in domainEntities) { var type = entityEntry.Entity.GetType(); var eventsFieldInfo = TryGetEventsFieldInfo(type); - if (eventsFieldInfo?.GetValue(entityEntry.Entity) is IList list) + if (eventsFieldInfo.GetValue(entityEntry.Entity) is IList list) { list.Clear(); } } - return new InterceptionResult(); - - static FieldInfo? TryGetEventsFieldInfo(Type type) + static FieldInfo TryGetEventsFieldInfo(Type type) { if (FieldInfoCache.TryGetValue(type, out var fi)) { return fi; } - var baseType = TraverseBaseEntityType(type); - - if (baseType is null) - { - return default; - } + var fields = TraverseBackingFields(type); var fieldInfo = - baseType.GetField("events", BindingFlags.Instance | BindingFlags.NonPublic) - ?? throw new FieldAccessException(); + fields.FirstOrDefault(f => + f.Name is "events" or "_events" && f.FieldType.IsAssignableTo(typeof(IList))) + ?? throw new FieldAccessException("Unable to find backing field 'events' or '_events'."); FieldInfoCache.AddOrUpdate( type, @@ -108,23 +158,20 @@ public override async ValueTask> SavingChangesAsync( return fieldInfo; } - static Type? TraverseBaseEntityType(Type? type) + static IEnumerable TraverseBackingFields(Type type) { - while (true) - { - var baseType = type?.BaseType; + var currentType = type; - if (baseType is null) - { - return null; - } + while (currentType is not null) + { + var fields = currentType.GetFields(BindingFlags.Instance | BindingFlags.NonPublic); - if (baseType.IsGenericType && baseType.GetGenericTypeDefinition() == typeof(Entity<>)) + foreach (var field in fields) { - return baseType; + yield return field; } - type = baseType; + currentType = currentType.BaseType; } } }