From 02836a4623b193cfd561712a8f3e1b1fd0560573 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20Tr=C3=BCtschel?= Date: Thu, 21 Jun 2018 14:26:00 +0200 Subject: [PATCH] Fixes #1 --- src/AsyncEnumerator/AsyncEnumerator.cs | 10 ++++------ .../AsyncEnumeratorTests.cs | 16 ++++++++++++++++ src/AsyncEnumeratorTests/AsyncSequenceTests.cs | 17 +++++++++++++++++ 3 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/AsyncEnumerator/AsyncEnumerator.cs b/src/AsyncEnumerator/AsyncEnumerator.cs index a99a848..f471396 100644 --- a/src/AsyncEnumerator/AsyncEnumerator.cs +++ b/src/AsyncEnumerator/AsyncEnumerator.cs @@ -18,7 +18,7 @@ public class AsyncEnumerator : TaskLikeBase, IAsyncEnumeratorProducer, IAs private ExceptionDispatchInfo _exception; private bool _isStarted; - private TaskCompletionSource _nextSource; + private TaskCompletionSource _nextSource =new TaskCompletionSource(); private TaskCompletionSource _yieldSource; public static TaskProvider> Capture() => TaskProvider>.Instance; @@ -32,7 +32,7 @@ public Task MoveNextAsync() if (!_isStarted) { _isStarted = true; - return Task.FromResult(true); + return _nextSource.Task; } _nextSource = new TaskCompletionSource(); @@ -45,7 +45,7 @@ public Task MoveNextAsync() internal override void SetException(ExceptionDispatchInfo exception) { _exception = exception; - _nextSource?.TrySetException(exception.SourceException); + _nextSource.TrySetException(exception.SourceException); } T IAsyncEnumeratorProducer.Break() @@ -65,10 +65,8 @@ Task IAsyncEnumeratorProducer.Pause() Task IAsyncEnumeratorProducer.Return(T value) { Current = value; - _yieldSource = new TaskCompletionSource(); - - _nextSource?.TrySetResult(true); + _nextSource.TrySetResult(true); return _yieldSource.Task; } diff --git a/src/AsyncEnumeratorTests/AsyncEnumeratorTests.cs b/src/AsyncEnumeratorTests/AsyncEnumeratorTests.cs index fc848b9..f6e4243 100644 --- a/src/AsyncEnumeratorTests/AsyncEnumeratorTests.cs +++ b/src/AsyncEnumeratorTests/AsyncEnumeratorTests.cs @@ -54,6 +54,15 @@ public async Task EnumerationAdvancesCorrectlyAndCompletes2() Assert.IsTrue(seq.IsCompleted, "Enumeration did not complete after return."); } + [TestMethod] + public async Task EmptyEnumeratorTest() + { + var seq = GetEmptyEnumerator(); + + Assert.IsFalse(await seq.MoveNextAsync(), $"Call to {nameof(seq.MoveNextAsync)} did not return false after enumeration completed."); + + Assert.IsTrue(seq.IsCompleted, "Enumeration did not complete after return."); + } private static async AsyncEnumerator ExceptionTest1() { @@ -89,5 +98,12 @@ private static async AsyncEnumerator Test2() return yield.Break(); } + private static async AsyncEnumerator GetEmptyEnumerator() + { + var yield = await AsyncEnumerator.Capture(); + + return yield.Break(); + } + } } diff --git a/src/AsyncEnumeratorTests/AsyncSequenceTests.cs b/src/AsyncEnumeratorTests/AsyncSequenceTests.cs index 8e1f6d9..7165d17 100644 --- a/src/AsyncEnumeratorTests/AsyncSequenceTests.cs +++ b/src/AsyncEnumeratorTests/AsyncSequenceTests.cs @@ -54,6 +54,16 @@ public async Task EnumerationAdvancesCorrectlyAndCompletes2() Assert.IsTrue(seq.IsCompleted, "Enumeration did not complete after return."); } + [TestMethod] + public async Task EmptyEnumeratorTest() + { + var seq = GetEmptyEnumerator(); + + Assert.IsFalse(await seq.MoveNextAsync(), $"Call to {nameof(seq.MoveNextAsync)} did not return false after enumeration completed."); + + Assert.IsTrue(seq.IsCompleted, "Enumeration did not complete after return."); + } + private static async AsyncEnumerator ExceptionTest1() { var yield = await AsyncEnumerator.Capture(); @@ -91,5 +101,12 @@ private static async AsyncSequence Test2() return yield.Break(); } + + private static async AsyncSequence GetEmptyEnumerator() + { + var yield = await AsyncSequence.Capture(); + + return yield.Break(); + } } }