Skip to content

Commit

Permalink
Use AsyncCollection in shared helper (Azure#7055)
Browse files Browse the repository at this point in the history
  • Loading branch information
pakrym authored Jul 31, 2019
1 parent 2ed4fbc commit 218a8a5
Show file tree
Hide file tree
Showing 21 changed files with 165 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
<ItemGroup>
<Compile Include="$(AzureCoreSharedSources)ArrayBufferWriter.cs" />
<Compile Include="$(AzureCoreSharedSources)HashCodeBuilder.cs" />
<Compile Include="$(AzureCoreSharedSources)PageResponse.cs" />
<Compile Include="$(AzureCoreSharedSources)PageResponseEnumerator.cs" />
</ItemGroup>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ public virtual Response<ConfigurationSetting> Get(string key, string label = def
/// </summary>
/// <param name="selector">Set of options for selecting <see cref="ConfigurationSetting"/> from the configuration store.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
public virtual IAsyncEnumerable<Response<ConfigurationSetting>> GetSettingsAsync(SettingSelector selector, CancellationToken cancellationToken = default)
public virtual AsyncCollection<ConfigurationSetting> GetSettingsAsync(SettingSelector selector, CancellationToken cancellationToken = default)
{
return PageResponseEnumerator.CreateAsyncEnumerable(nextLink => GetSettingsPageAsync(selector, nextLink, cancellationToken));
}
Expand All @@ -589,7 +589,7 @@ public virtual IEnumerable<Response<ConfigurationSetting>> GetSettings(SettingSe
/// </summary>
/// <param name="selector">Set of options for selecting <see cref="ConfigurationSetting"/> from the configuration store.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
public virtual IAsyncEnumerable<Response<ConfigurationSetting>> GetRevisionsAsync(SettingSelector selector, CancellationToken cancellationToken = default)
public virtual AsyncCollection<ConfigurationSetting> GetRevisionsAsync(SettingSelector selector, CancellationToken cancellationToken = default)
{
return PageResponseEnumerator.CreateAsyncEnumerable(nextLink => GetRevisionsPageAsync(selector, nextLink, cancellationToken));
}
Expand Down Expand Up @@ -630,7 +630,7 @@ private Request CreateGetRequest(string key, string label, DateTimeOffset accept
/// <param name="selector">Set of options for selecting settings from the configuration store.</param>
/// <param name="pageLink"></param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
private async Task<PageResponse<ConfigurationSetting>> GetSettingsPageAsync(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
private async Task<Page<ConfigurationSetting>> GetSettingsPageAsync(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
{
using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope("Azure.ApplicationModel.Configuration.ConfigurationClient.GetSettingsPage");
scope.Start();
Expand All @@ -645,7 +645,7 @@ private async Task<PageResponse<ConfigurationSetting>> GetSettingsPageAsync(Sett
case 200:
case 206:
SettingBatch settingBatch = await ConfigurationServiceSerializer.ParseBatchAsync(response, cancellationToken).ConfigureAwait(false);
return new PageResponse<ConfigurationSetting>(settingBatch.Settings, response, settingBatch.NextBatchLink);
return new Page<ConfigurationSetting>(settingBatch.Settings, settingBatch.NextBatchLink, response);
default:
throw await response.CreateRequestFailedExceptionAsync().ConfigureAwait(false);
}
Expand All @@ -663,7 +663,7 @@ private async Task<PageResponse<ConfigurationSetting>> GetSettingsPageAsync(Sett
/// <param name="selector">Set of options for selecting settings from the configuration store.</param>
/// <param name="pageLink"></param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
private PageResponse<ConfigurationSetting> GetSettingsPage(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
private Page<ConfigurationSetting> GetSettingsPage(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
{
using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope("Azure.ApplicationModel.Configuration.ConfigurationClient.GetSettingsPage");
scope.Start();
Expand All @@ -678,7 +678,7 @@ private PageResponse<ConfigurationSetting> GetSettingsPage(SettingSelector selec
case 200:
case 206:
SettingBatch settingBatch = ConfigurationServiceSerializer.ParseBatch(response);
return new PageResponse<ConfigurationSetting>(settingBatch.Settings, response, settingBatch.NextBatchLink);
return new Page<ConfigurationSetting>(settingBatch.Settings, settingBatch.NextBatchLink, response);
default:
throw response.CreateRequestFailedException();
}
Expand Down Expand Up @@ -712,7 +712,7 @@ private Request CreateBatchRequest(SettingSelector selector, string pageLink)
/// <param name="selector">Set of options for selecting settings from the configuration store.</param>
/// <param name="pageLink"></param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
private async Task<PageResponse<ConfigurationSetting>> GetRevisionsPageAsync(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
private async Task<Page<ConfigurationSetting>> GetRevisionsPageAsync(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
{
using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope("Azure.ApplicationModel.Configuration.ConfigurationClient.GetRevisionsPage");
scope.Start();
Expand All @@ -726,7 +726,7 @@ private async Task<PageResponse<ConfigurationSetting>> GetRevisionsPageAsync(Set
case 200:
case 206:
SettingBatch settingBatch = await ConfigurationServiceSerializer.ParseBatchAsync(response, cancellationToken).ConfigureAwait(false);
return new PageResponse<ConfigurationSetting>(settingBatch.Settings, response, settingBatch.NextBatchLink);
return new Page<ConfigurationSetting>(settingBatch.Settings, settingBatch.NextBatchLink, response);
default:
throw await response.CreateRequestFailedExceptionAsync().ConfigureAwait(false);
}
Expand All @@ -745,7 +745,7 @@ private async Task<PageResponse<ConfigurationSetting>> GetRevisionsPageAsync(Set
/// <param name="selector">Set of options for selecting settings from the configuration store.</param>
/// <param name="pageLink"></param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
private PageResponse<ConfigurationSetting> GetRevisionsPage(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
private Page<ConfigurationSetting> GetRevisionsPage(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
{
using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope("Azure.ApplicationModel.Configuration.ConfigurationClient.GetRevisionsPage");
scope.Start();
Expand All @@ -759,7 +759,7 @@ private PageResponse<ConfigurationSetting> GetRevisionsPage(SettingSelector sele
case 200:
case 206:
SettingBatch settingBatch = ConfigurationServiceSerializer.ParseBatch(response);
return new PageResponse<ConfigurationSetting>(settingBatch.Settings, response, settingBatch.NextBatchLink);
return new Page<ConfigurationSetting>(settingBatch.Settings, settingBatch.NextBatchLink, response);
default:
throw response.CreateRequestFailedException();
}
Expand Down
12 changes: 11 additions & 1 deletion sdk/core/Azure.Core/src/AsyncCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.ComponentModel;
using System.Threading;
using System.Threading.Tasks;

namespace Azure
{
Expand Down Expand Up @@ -66,7 +67,16 @@ public abstract IAsyncEnumerable<Page<T>> ByPage(
/// enumerating asynchronously.
/// </param>
/// <returns>An async sequence of values.</returns>
public abstract IAsyncEnumerator<Response<T>> GetAsyncEnumerator(CancellationToken cancellationToken = default);
public virtual async IAsyncEnumerator<Response<T>> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
await foreach (Page<T> page in ByPage().ConfigureAwait(false).WithCancellation(cancellationToken))
{
foreach (T value in page.Values)
{
yield return new Response<T>(page.GetRawResponse(), value);
}
}
}

/// <summary>
/// Creates a string representation of an <see cref="AsyncCollection{T}"/>.
Expand Down
21 changes: 0 additions & 21 deletions sdk/core/Azure.Core/src/Shared/PageResponse.cs

This file was deleted.

44 changes: 31 additions & 13 deletions sdk/core/Azure.Core/src/Shared/PageResponseEnumerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,57 @@
// Licensed under the MIT License.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace Azure.Core
{
internal static class PageResponseEnumerator
{
public static IEnumerable<Response<T>> CreateEnumerable<T>(Func<string, PageResponse<T>> pageFunc)
public static IEnumerable<Response<T>> CreateEnumerable<T>(Func<string, Page<T>> pageFunc)
{
string nextLink = null;
do
{
PageResponse<T> pageResponse = pageFunc(nextLink);
Page<T> pageResponse = pageFunc(nextLink);
foreach (T setting in pageResponse.Values)
{
yield return new Response<T>(pageResponse.Response, setting);
yield return new Response<T>(pageResponse.GetRawResponse(), setting);
}
nextLink = pageResponse.NextLink;
nextLink = pageResponse.ContinuationToken;
} while (nextLink != null);
}

public static async IAsyncEnumerable<Response<T>> CreateAsyncEnumerable<T>(Func<string, Task<PageResponse<T>>> pageFunc)
public static AsyncCollection<T> CreateAsyncEnumerable<T>(Func<string, Task<Page<T>>> pageFunc)
{
string nextLink = null;
do
return new FuncAsyncCollection<T>((continuationToken, pageSizeHint) => pageFunc(continuationToken));
}

public static AsyncCollection<T> CreateAsyncEnumerable<T>(Func<string, int?, Task<Page<T>>> pageFunc)
{
return new FuncAsyncCollection<T>(pageFunc);
}

internal class FuncAsyncCollection<T>: AsyncCollection<T>
{
private readonly Func<string, int?, Task<Page<T>>> _pageFunc;

public FuncAsyncCollection(Func<string, int?, Task<Page<T>>> pageFunc)
{
PageResponse<T> pageResponse = await pageFunc(nextLink).ConfigureAwait(false);
foreach (T setting in pageResponse.Values)
_pageFunc = pageFunc;
}

public override async IAsyncEnumerable<Page<T>> ByPage(string continuationToken = default, int? pageSizeHint = default)
{
do
{
yield return new Response<T>(pageResponse.Response, setting);
}
nextLink = pageResponse.NextLink;
} while (nextLink != null);
Page<T> pageResponse = await _pageFunc(continuationToken, pageSizeHint).ConfigureAwait(false);
yield return pageResponse;
continuationToken = pageResponse.ContinuationToken;
} while (continuationToken != null);
}
}
}
}
38 changes: 38 additions & 0 deletions sdk/core/Azure.Core/tests/TestFramework/AsyncOnlyAttribute.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Text;
using NUnit.Framework;
using NUnit.Framework.Interfaces;
using NUnit.Framework.Internal;

namespace Azure.Core.Testing
{

[AttributeUsage(AttributeTargets.Method|AttributeTargets.Class|AttributeTargets.Assembly, AllowMultiple=false, Inherited=false)]
public class AsyncOnlyAttribute : NUnitAttribute, IApplyToTest
{
#region IApplyToTest members

/// <summary>
/// Modifies a test by marking it as Ignored.
/// </summary>
/// <param name="test">The test to modify</param>
public void ApplyToTest(Test test)
{
if (test.RunState != RunState.NotRunnable)
{
// This is an unfortunate implementation but it's the only one I was able to figure out
string testParameters = test.FullName.Substring(test.ClassName.Length);
if (testParameters.StartsWith("(False"))
{
test.RunState = RunState.Ignored;
}
}
}

#endregion
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,20 +82,10 @@ public void Intercept(IInvocation invocation)
// Map IEnumerable to IAsyncEnumerable
if (returnsIEnumerable)
{
if (invocation.Method.ReturnType.IsGenericType &&
invocation.Method.ReturnType.GetGenericTypeDefinition().Name == "AsyncCollection`1")
{
// AsyncCollection can be used as either a sync or async
// collection so there's no need to wrap it in an
// IAsyncEnumerable
invocation.ReturnValue = result;
}
else
{
invocation.ReturnValue = Activator.CreateInstance(
typeof(AsyncEnumerableWrapper<>).MakeGenericType(returnType.GenericTypeArguments),
new[] { result });
}
Type[] modelType = returnType.GenericTypeArguments.Single().GenericTypeArguments;
Type wrapperType = typeof(AsyncEnumerableWrapper<>).MakeGenericType(modelType);

invocation.ReturnValue = Activator.CreateInstance(wrapperType, new [] { result });
}
else
{
Expand All @@ -120,34 +110,52 @@ private static MethodInfo GetMethod(IInvocation invocation, string nonAsyncMetho
return invocation.TargetType.GetMethod(nonAsyncMethodName, BindingFlags.Public | BindingFlags.Instance, null, types, null);
}

private class AsyncEnumerableWrapper<T> : IAsyncEnumerable<T>
private class AsyncEnumerableWrapper<T> : AsyncCollection<T>
{
private readonly IEnumerable<T> _enumerable;
private readonly IEnumerable<Response<T>> _enumerable;

public AsyncEnumerableWrapper(IEnumerable<T> enumerable)
public AsyncEnumerableWrapper(IEnumerable<Response<T>> enumerable)
{
_enumerable = enumerable;
}

public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = new CancellationToken())
#pragma warning disable 1998
public override async IAsyncEnumerable<Page<T>> ByPage(string continuationToken = default, int? pageSizeHint = default)
#pragma warning restore 1998
{
return new Enumerator(_enumerable.GetEnumerator());
if (continuationToken != null)
{
throw new InvalidOperationException("Calling ByPage with a continuationToken is not supported in the sync mode");
}

if (pageSizeHint != null)
{
throw new InvalidOperationException("Calling ByPage with a pageSizeHint is not supported in the sync mode");
}

foreach (Response<T> response in _enumerable)
{
yield return new Page<T>(new [] { response.Value}, null, response.GetRawResponse());
}
}

private class Enumerator: IAsyncEnumerator<T>
private class SingleEnumerable: IAsyncEnumerable<Page<T>>, IAsyncEnumerator<Page<T>>
{
private readonly IEnumerator<T> _enumerator;

public Enumerator(IEnumerator<T> enumerator)
public SingleEnumerable(Page<T> value)
{
_enumerator = enumerator;
Current = value;
}

public ValueTask DisposeAsync() => default;

public ValueTask<bool> MoveNextAsync() => new ValueTask<bool>(_enumerator.MoveNext());
public ValueTask<bool> MoveNextAsync() => new ValueTask<bool>(false);

public T Current => _enumerator.Current;
public Page<T> Current { get; }

public IAsyncEnumerator<Page<T>> GetAsyncEnumerator(CancellationToken cancellationToken = new CancellationToken())
{
return this;
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

<ItemGroup>
<Compile Include="$(AzureCoreSharedSources)ArrayBufferWriter.cs" />
<Compile Include="$(AzureCoreSharedSources)PageResponse.cs" />
<Compile Include="$(AzureCoreSharedSources)PageResponseEnumerator.cs" />
</ItemGroup>

Expand Down
Loading

0 comments on commit 218a8a5

Please sign in to comment.