diff --git a/src/Microsoft.DotNet.Interactive.Tests/KernelCommandSchedulerTests.cs b/src/Microsoft.DotNet.Interactive.Tests/KernelCommandSchedulerTests.cs index 96ab141c49..aa2437e22b 100644 --- a/src/Microsoft.DotNet.Interactive.Tests/KernelCommandSchedulerTests.cs +++ b/src/Microsoft.DotNet.Interactive.Tests/KernelCommandSchedulerTests.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using FluentAssertions; + using Microsoft.DotNet.Interactive.Tests.Utility; using Pocket; @@ -75,10 +76,10 @@ public async Task scheduled_work_does_not_execute_in_parallel() var concurrencyCounter = 0; var maxObservedParallelism = 0; var tasks = new Task[3]; - + for (var i = 0; i < 3; i++) { - var task = scheduler.Schedule(i, async _ => + var task = scheduler.Schedule(i, async _ => { Interlocked.Increment(ref concurrencyCounter); @@ -86,10 +87,10 @@ public async Task scheduled_work_does_not_execute_in_parallel() maxObservedParallelism = Math.Max(concurrencyCounter, maxObservedParallelism); Interlocked.Decrement(ref concurrencyCounter); - } ); + }); tasks[i] = task; } - + await Task.WhenAll(tasks); maxObservedParallelism.Should().Be(1); @@ -107,7 +108,7 @@ void PerformWork(int v) using var scheduler = new KernelScheduler(); scheduler.RegisterDeferredOperationSource( - (v,_) => Enumerable.Repeat(v * 10, v), PerformWork); + (v, _) => Enumerable.Repeat(v * 10, v), PerformWork); for (var i = 1; i <= 3; i++) { @@ -144,6 +145,132 @@ void PerformWork(int v) executionList.Should().BeEquivalentTo(1); } + [Fact] + public async Task cancelled_work_prevents_any_scheduled_work_from_executing() + { + var executionList = new List(); + using var scheduler = new KernelScheduler(); + var barrier = new Barrier(2); + + async Task PerformWork(int v) + { + barrier.SignalAndWait(); + await Task.Delay(3000); + executionList.Add(v); + } + + var scheduledWork = new List + { + scheduler.Schedule(1, PerformWork), + scheduler.Schedule(2, executionList.Add), + scheduler.Schedule(3, executionList.Add) + }; + + barrier.SignalAndWait(); + scheduler.Cancel(); + try + { + await Task.WhenAll(scheduledWork); + } + catch (TaskCanceledException) + { + + } + + executionList.Should().BeEmpty(); + } + + [Fact] + public void cancelling_work_throws_exception() + { + var executionList = new List(); + using var scheduler = new KernelScheduler(); + var barrier = new Barrier(2); + + async Task PerformWork(int v) + { + barrier.SignalAndWait(); + await Task.Delay(3000); + executionList.Add(v); + } + + var scheduledWork = new List + { + scheduler.Schedule(1, PerformWork), + scheduler.Schedule(2, executionList.Add), + scheduler.Schedule(3, executionList.Add) + }; + + barrier.SignalAndWait(); + scheduler.Cancel(); + var operation = new Action( () => Task.WhenAll(scheduledWork).Wait(5000)); + + operation.Should().Throw(); + } + + [Fact] + public async Task exception_in_scheduled_work_halts_execution() + { + var executionList = new List(); + using var scheduler = new KernelScheduler(); + var barrier = new Barrier(2); + + void PerformWork(int v) + { + barrier.SignalAndWait(); + throw new InvalidOperationException("test exception"); + } + + var scheduledWork = new List + { + scheduler.Schedule(1, PerformWork), + scheduler.Schedule(2, executionList.Add), + scheduler.Schedule(3, executionList.Add) + }; + + barrier.SignalAndWait(); + try + { + await Task.WhenAll(scheduledWork); + } + catch(InvalidOperationException) + { + + } + + executionList.Should().BeEmpty(); + } + + [Fact] + public void exception_in_scheduled_work_is_propagated() + { + var executionList = new List(); + using var scheduler = new KernelScheduler(); + var barrier = new Barrier(2); + + void PerformWork(int v) + { + barrier.SignalAndWait(); + throw new InvalidOperationException("test exception"); + } + + var scheduledWork = new List + { + scheduler.Schedule(1, PerformWork), + scheduler.Schedule(2, executionList.Add), + scheduler.Schedule(3, executionList.Add) + }; + + barrier.SignalAndWait(); + var operation = new Action(() => Task.WhenAll(scheduledWork).Wait(5000)); + + operation.Should().Throw() + .Which + .Message + .Should() + .Be("test exception"); + } + [Fact] public async Task awaiting_for_work_to_complete_does_not_wait_for_subsequent_work() { @@ -162,9 +289,9 @@ async Task PerformWorkAsync(int v) _ = scheduler.Schedule(3, PerformWorkAsync); - executionList.Should().BeEquivalentSequenceTo( 1, 2); + executionList.Should().BeEquivalentSequenceTo(1, 2); + - } [Fact] @@ -186,168 +313,7 @@ void PerformWork(int v) await scheduler.Schedule(i, PerformWork, $"scope{i}"); } - executionList.Should().BeEquivalentSequenceTo( 1, 20, 20, 2,3); + executionList.Should().BeEquivalentSequenceTo(1, 20, 20, 2, 3); } - - //[Fact] - //public async Task command_execute_on_kernel_specified_at_scheduling_time() - //{ - // var commandsHandledOnKernel1 = new List(); - // var commandsHandledOnKernel2 = new List(); - - // var scheduler = new KernelCommandScheduler(); - - // var kernel1 = new FakeKernel("kernel1") - // { - // Handle = (command, _) => - // { - // commandsHandledOnKernel1.Add(command); - // return Task.CompletedTask; - // } - // }; - // var kernel2 = new FakeKernel("kernel2") - // { - // Handle = (command, _) => - // { - // commandsHandledOnKernel2.Add(command); - // return Task.CompletedTask; - // } - // }; - - // var command1 = new SubmitCode("for kernel 1", kernel1.Name); - // var command2 = new SubmitCode("for kernel 2", kernel2.Name); - - // await scheduler.Schedule(command1); - // await scheduler.Schedule(command2); - - // commandsHandledOnKernel1.Should().ContainSingle().Which.Should().Be(command1); - // commandsHandledOnKernel2.Should().ContainSingle().Which.Should().Be(command2); - //} - - //[Fact] - //public async Task scheduling_a_command_will_defer_deferred_commands_scheduled_on_same_kernel() - //{ - // var commandsHandledOnKernel1 = new List(); - - // var scheduler = new KernelCommandScheduler(); - - // var kernel1 = new FakeKernel("kernel1") - // { - // Handle = (command, _) => - // { - // commandsHandledOnKernel1.Add(command); - // return Task.CompletedTask; - // } - // }; - // var kernel2 = new FakeKernel("kernel2") - // { - // Handle = (_, _) => Task.CompletedTask - // }; - - // var deferredCommand1 = new SubmitCode("deferred for kernel 1", kernel1.Name); - // var deferredCommand2 = new SubmitCode("deferred for kernel 2", kernel2.Name); - // var deferredCommand3 = new SubmitCode("deferred for kernel 1", kernel1.Name); - // var command1 = new SubmitCode("for kernel 1", kernel1.Name); - - // scheduler.DeferCommand(deferredCommand1); - // scheduler.DeferCommand(deferredCommand2); - // scheduler.DeferCommand(deferredCommand3); - // await scheduler.Schedule(command1); - - // commandsHandledOnKernel1.Should().NotContain(deferredCommand2); - // commandsHandledOnKernel1.Should().BeEquivalentSequenceTo(deferredCommand1, deferredCommand3, command1); - //} - - //[Fact] - //public async Task deferred_command_not_executed_are_still_in_deferred_queue() - //{ - // var commandsHandledOnKernel1 = new List(); - // var commandsHandledOnKernel2 = new List(); - - // var scheduler = new KernelCommandScheduler(); - - // var kernel1 = new FakeKernel("kernel1") - // { - // Handle = (command, _) => - // { - // commandsHandledOnKernel1.Add(command); - // return Task.CompletedTask; - // } - // }; - // var kernel2 = new FakeKernel("kernel2") - // { - // Handle = (command, _) => - // { - // commandsHandledOnKernel2.Add(command); - // return Task.CompletedTask; - // } - // }; - - // var deferredCommand1 = new SubmitCode("deferred for kernel 1"); - // var deferredCommand2 = new SubmitCode("deferred for kernel 2"); - // var deferredCommand3 = new SubmitCode("deferred for kernel 1"); - // var command1 = new SubmitCode("for kernel 1"); - // var command2 = new SubmitCode("for kernel 2"); - - // scheduler.DeferCommand(deferredCommand1, kernel1); - // scheduler.DeferCommand(deferredCommand2, kernel2); - // scheduler.DeferCommand(deferredCommand3, kernel1); - // await scheduler.Schedule(command1, kernel1); - - // commandsHandledOnKernel2.Should().BeEmpty(); - // commandsHandledOnKernel1.Should().NotContain(deferredCommand2); - // commandsHandledOnKernel1.Should().BeEquivalentSequenceTo(deferredCommand1, deferredCommand3, command1); - // await scheduler.Schedule(command2, kernel2); - // commandsHandledOnKernel2.Should().BeEquivalentSequenceTo(deferredCommand2, command2); - //} - - //[Fact] - //public async Task deferred_command_on_parent_kernel_are_executed_when_scheduling_command_on_child_kernel() - //{ - // var commandHandledList = new List<(KernelCommand command, Kernel kernel)>(); - - // var scheduler = new KernelCommandScheduler(); - - // var childKernel = new FakeKernel("kernel1") - // { - // Handle = (command, context) => command.InvokeAsync(context) - // }; - // var parentKernel = new CompositeKernel - // { - // childKernel - // }; - - // parentKernel.DefaultKernelName = childKernel.Name; - - // var deferredCommand1 = new TestCommand((command, context) => - // { - // commandHandledList.Add((command, context.HandlingKernel)); - // return Task.CompletedTask; - // }, childKernel.Name); - // var deferredCommand2 = new TestCommand((command, context) => - // { - // commandHandledList.Add((command, context.HandlingKernel)); - // return Task.CompletedTask; - // }, parentKernel.Name); - // var deferredCommand3 = new TestCommand((command, context) => - // { - // commandHandledList.Add((command, context.HandlingKernel)); - // return Task.CompletedTask; - // }, childKernel.Name); - // var command1 = new TestCommand((command, context) => - // { - // commandHandledList.Add((command, context.HandlingKernel)); - // return Task.CompletedTask; - // }, childKernel.Name); - - // scheduler.DeferCommand(deferredCommand1, childKernel); - // scheduler.DeferCommand(deferredCommand2, parentKernel); - // scheduler.DeferCommand(deferredCommand3, childKernel); - // await scheduler.Schedule(command1, childKernel); - - // commandHandledList.Select(e => e.command).Should().BeEquivalentSequenceTo(deferredCommand1, deferredCommand2, deferredCommand3, command1); - - // commandHandledList.Select(e => e.kernel).Should().BeEquivalentSequenceTo(childKernel, parentKernel, childKernel, childKernel); - //} } - } \ No newline at end of file +} \ No newline at end of file diff --git a/src/Microsoft.DotNet.Interactive/KernelScheduler.cs b/src/Microsoft.DotNet.Interactive/KernelScheduler.cs index 31aea895aa..a4fbcc740c 100644 --- a/src/Microsoft.DotNet.Interactive/KernelScheduler.cs +++ b/src/Microsoft.DotNet.Interactive/KernelScheduler.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Linq.Expressions; using System.Threading; using System.Threading.Tasks; @@ -73,37 +74,43 @@ private async Task ProcessScheduledOperations(CancellationToken cancellationToke } else { - return; } } - if (operation is not null) + try { - // get all deferred operations and pump in - foreach (var deferredOperationRegistration in _deferredOperationRegistrations) + if (operation is not null) { - foreach (var deferred in deferredOperationRegistration.GetDeferredOperations(operation.Value, operation.Scope)) + // get all deferred operations and pump in + foreach (var deferredOperationRegistration in _deferredOperationRegistrations) { - var deferredOperation = new ScheduledOperation(deferred, deferredOperationRegistration.OnExecute, operation.Scope); - - cancellationToken.Register(() => + foreach (var deferred in deferredOperationRegistration.GetDeferredOperations(operation.Value, + operation.Scope)) { - if (!deferredOperation.CompletionSource.Task.IsCompleted) + var deferredOperation = new ScheduledOperation(deferred, + deferredOperationRegistration.OnExecute, operation.Scope); + + cancellationToken.Register(() => { - deferredOperation.CompletionSource.SetCanceled(); - } - }); + if (!deferredOperation.CompletionSource.Task.IsCompleted) + { + deferredOperation.CompletionSource.SetCanceled(); + } + }); - await DoWork(deferredOperation); + await DoWork(deferredOperation); + } } - } - - await DoWork(operation); + await DoWork(operation); + } + } + catch + { + Cancel(); + throw; } - - async Task DoWork(ScheduledOperation scheduleOperation) { @@ -118,6 +125,7 @@ async Task DoWork(ScheduledOperation scheduleOperation) catch (Exception e) { scheduleOperation.CompletionSource.SetException(e); + throw; } } }