Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use AsyncCollection in shared helper #7055

Merged
merged 3 commits into from
Jul 31, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
<ItemGroup>
<Compile Include="$(AzureCoreSharedSources)ArrayBufferWriter.cs" />
<Compile Include="$(AzureCoreSharedSources)HashCodeBuilder.cs" />
<Compile Include="$(AzureCoreSharedSources)PageResponse.cs" />
<Compile Include="$(AzureCoreSharedSources)PageResponseEnumerator.cs" />
</ItemGroup>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ public virtual Response<ConfigurationSetting> Get(string key, string label = def
/// </summary>
/// <param name="selector">Set of options for selecting <see cref="ConfigurationSetting"/> from the configuration store.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
public virtual IAsyncEnumerable<Response<ConfigurationSetting>> GetSettingsAsync(SettingSelector selector, CancellationToken cancellationToken = default)
public virtual AsyncCollection<ConfigurationSetting> GetSettingsAsync(SettingSelector selector, CancellationToken cancellationToken = default)
{
return PageResponseEnumerator.CreateAsyncEnumerable(nextLink => GetSettingsPageAsync(selector, nextLink, cancellationToken));
}
Expand All @@ -589,7 +589,7 @@ public virtual IEnumerable<Response<ConfigurationSetting>> GetSettings(SettingSe
/// </summary>
/// <param name="selector">Set of options for selecting <see cref="ConfigurationSetting"/> from the configuration store.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
public virtual IAsyncEnumerable<Response<ConfigurationSetting>> GetRevisionsAsync(SettingSelector selector, CancellationToken cancellationToken = default)
public virtual AsyncCollection<ConfigurationSetting> GetRevisionsAsync(SettingSelector selector, CancellationToken cancellationToken = default)
{
return PageResponseEnumerator.CreateAsyncEnumerable(nextLink => GetRevisionsPageAsync(selector, nextLink, cancellationToken));
}
Expand Down Expand Up @@ -630,7 +630,7 @@ private Request CreateGetRequest(string key, string label, DateTimeOffset accept
/// <param name="selector">Set of options for selecting settings from the configuration store.</param>
/// <param name="pageLink"></param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
private async Task<PageResponse<ConfigurationSetting>> GetSettingsPageAsync(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
private async Task<Page<ConfigurationSetting>> GetSettingsPageAsync(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
{
using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope("Azure.ApplicationModel.Configuration.ConfigurationClient.GetSettingsPage");
scope.Start();
Expand All @@ -645,7 +645,7 @@ private async Task<PageResponse<ConfigurationSetting>> GetSettingsPageAsync(Sett
case 200:
case 206:
SettingBatch settingBatch = await ConfigurationServiceSerializer.ParseBatchAsync(response, cancellationToken).ConfigureAwait(false);
return new PageResponse<ConfigurationSetting>(settingBatch.Settings, response, settingBatch.NextBatchLink);
return new Page<ConfigurationSetting>(settingBatch.Settings, settingBatch.NextBatchLink, response);
default:
throw await response.CreateRequestFailedExceptionAsync().ConfigureAwait(false);
}
Expand All @@ -663,7 +663,7 @@ private async Task<PageResponse<ConfigurationSetting>> GetSettingsPageAsync(Sett
/// <param name="selector">Set of options for selecting settings from the configuration store.</param>
/// <param name="pageLink"></param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
private PageResponse<ConfigurationSetting> GetSettingsPage(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
private Page<ConfigurationSetting> GetSettingsPage(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
{
using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope("Azure.ApplicationModel.Configuration.ConfigurationClient.GetSettingsPage");
scope.Start();
Expand All @@ -678,7 +678,7 @@ private PageResponse<ConfigurationSetting> GetSettingsPage(SettingSelector selec
case 200:
case 206:
SettingBatch settingBatch = ConfigurationServiceSerializer.ParseBatch(response);
return new PageResponse<ConfigurationSetting>(settingBatch.Settings, response, settingBatch.NextBatchLink);
return new Page<ConfigurationSetting>(settingBatch.Settings, settingBatch.NextBatchLink, response);
default:
throw response.CreateRequestFailedException();
}
Expand Down Expand Up @@ -712,7 +712,7 @@ private Request CreateBatchRequest(SettingSelector selector, string pageLink)
/// <param name="selector">Set of options for selecting settings from the configuration store.</param>
/// <param name="pageLink"></param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
private async Task<PageResponse<ConfigurationSetting>> GetRevisionsPageAsync(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
private async Task<Page<ConfigurationSetting>> GetRevisionsPageAsync(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
{
using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope("Azure.ApplicationModel.Configuration.ConfigurationClient.GetRevisionsPage");
scope.Start();
Expand All @@ -726,7 +726,7 @@ private async Task<PageResponse<ConfigurationSetting>> GetRevisionsPageAsync(Set
case 200:
case 206:
SettingBatch settingBatch = await ConfigurationServiceSerializer.ParseBatchAsync(response, cancellationToken).ConfigureAwait(false);
return new PageResponse<ConfigurationSetting>(settingBatch.Settings, response, settingBatch.NextBatchLink);
return new Page<ConfigurationSetting>(settingBatch.Settings, settingBatch.NextBatchLink, response);
default:
throw await response.CreateRequestFailedExceptionAsync().ConfigureAwait(false);
}
Expand All @@ -745,7 +745,7 @@ private async Task<PageResponse<ConfigurationSetting>> GetRevisionsPageAsync(Set
/// <param name="selector">Set of options for selecting settings from the configuration store.</param>
/// <param name="pageLink"></param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
private PageResponse<ConfigurationSetting> GetRevisionsPage(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
private Page<ConfigurationSetting> GetRevisionsPage(SettingSelector selector, string pageLink, CancellationToken cancellationToken = default)
{
using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope("Azure.ApplicationModel.Configuration.ConfigurationClient.GetRevisionsPage");
scope.Start();
Expand All @@ -759,7 +759,7 @@ private PageResponse<ConfigurationSetting> GetRevisionsPage(SettingSelector sele
case 200:
case 206:
SettingBatch settingBatch = ConfigurationServiceSerializer.ParseBatch(response);
return new PageResponse<ConfigurationSetting>(settingBatch.Settings, response, settingBatch.NextBatchLink);
return new Page<ConfigurationSetting>(settingBatch.Settings, settingBatch.NextBatchLink, response);
default:
throw response.CreateRequestFailedException();
}
Expand Down
11 changes: 10 additions & 1 deletion sdk/core/Azure.Core/src/AsyncCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,16 @@ public abstract IAsyncEnumerable<Page<T>> ByPage(
/// enumerating asynchronously.
/// </param>
/// <returns>An async sequence of values.</returns>
public abstract IAsyncEnumerator<Response<T>> GetAsyncEnumerator(CancellationToken cancellationToken = default);
public virtual async IAsyncEnumerator<Response<T>> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
await foreach (Page<T> page in ByPage())
pakrym marked this conversation as resolved.
Show resolved Hide resolved
{
foreach (T value in page.Values)
{
yield return new Response<T>(page.GetRawResponse(), value);
}
}
}

/// <summary>
/// Creates a string representation of an <see cref="AsyncCollection{T}"/>.
Expand Down
21 changes: 0 additions & 21 deletions sdk/core/Azure.Core/src/Shared/PageResponse.cs

This file was deleted.

39 changes: 26 additions & 13 deletions sdk/core/Azure.Core/src/Shared/PageResponseEnumerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,52 @@
// Licensed under the MIT License.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace Azure.Core
{
internal static class PageResponseEnumerator
{
public static IEnumerable<Response<T>> CreateEnumerable<T>(Func<string, PageResponse<T>> pageFunc)
public static IEnumerable<Response<T>> CreateEnumerable<T>(Func<string, Page<T>> pageFunc)
{
string nextLink = null;
do
{
PageResponse<T> pageResponse = pageFunc(nextLink);
Page<T> pageResponse = pageFunc(nextLink);
foreach (T setting in pageResponse.Values)
{
yield return new Response<T>(pageResponse.Response, setting);
yield return new Response<T>(pageResponse.GetRawResponse(), setting);
}
nextLink = pageResponse.NextLink;
nextLink = pageResponse.ContinuationToken;
} while (nextLink != null);
}

public static async IAsyncEnumerable<Response<T>> CreateAsyncEnumerable<T>(Func<string, Task<PageResponse<T>>> pageFunc)
public static AsyncCollection<T> CreateAsyncEnumerable<T>(Func<string, Task<Page<T>>> pageFunc)
{
string nextLink = null;
do
return new FuncAsyncCollection<T>(pageFunc);
}

internal class FuncAsyncCollection<T>: AsyncCollection<T>
{
private readonly Func<string, Task<Page<T>>> _pageFunc;
pakrym marked this conversation as resolved.
Show resolved Hide resolved

public FuncAsyncCollection(Func<string, Task<Page<T>>> pageFunc)
{
PageResponse<T> pageResponse = await pageFunc(nextLink).ConfigureAwait(false);
foreach (T setting in pageResponse.Values)
_pageFunc = pageFunc;
}

public override async IAsyncEnumerable<Page<T>> ByPage(string continuationToken = default, int? pageSizeHint = default)
{
do
{
yield return new Response<T>(pageResponse.Response, setting);
}
nextLink = pageResponse.NextLink;
} while (nextLink != null);
Page<T> pageResponse = await _pageFunc(continuationToken).ConfigureAwait(false);
yield return pageResponse;
continuationToken = pageResponse.ContinuationToken;
} while (continuationToken != null);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,20 +82,10 @@ public void Intercept(IInvocation invocation)
// Map IEnumerable to IAsyncEnumerable
if (returnsIEnumerable)
{
if (invocation.Method.ReturnType.IsGenericType &&
invocation.Method.ReturnType.GetGenericTypeDefinition().Name == "AsyncCollection`1")
{
// AsyncCollection can be used as either a sync or async
// collection so there's no need to wrap it in an
// IAsyncEnumerable
invocation.ReturnValue = result;
}
else
{
invocation.ReturnValue = Activator.CreateInstance(
typeof(AsyncEnumerableWrapper<>).MakeGenericType(returnType.GenericTypeArguments),
new[] { result });
}
Type[] modelType = returnType.GenericTypeArguments.Single().GenericTypeArguments;
Type wrapperType = typeof(AsyncEnumerableWrapper<>).MakeGenericType(modelType);

invocation.ReturnValue = Activator.CreateInstance(wrapperType, new [] { result });
}
else
{
Expand All @@ -120,34 +110,47 @@ private static MethodInfo GetMethod(IInvocation invocation, string nonAsyncMetho
return invocation.TargetType.GetMethod(nonAsyncMethodName, BindingFlags.Public | BindingFlags.Instance, null, types, null);
}

private class AsyncEnumerableWrapper<T> : IAsyncEnumerable<T>
private class AsyncEnumerableWrapper<T> : AsyncCollection<T>
{
private readonly IEnumerable<T> _enumerable;
private readonly IEnumerable<Response<T>> _enumerable;

public AsyncEnumerableWrapper(IEnumerable<T> enumerable)
public AsyncEnumerableWrapper(IEnumerable<Response<T>> enumerable)
{
_enumerable = enumerable;
}

public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = new CancellationToken())
#pragma warning disable 1998
public override async IAsyncEnumerable<Page<T>> ByPage(string continuationToken = default, int? pageSizeHint = default)
#pragma warning restore 1998
{
return new Enumerator(_enumerable.GetEnumerator());
if (continuationToken != null)
{
throw new InvalidOperationException("Calling ByPage with a continuationToken is not supported in the sync mode");
}

foreach (Response<T> response in _enumerable)
{
yield return new Page<T>(new [] { response.Value}, null, response.GetRawResponse());
}
}

private class Enumerator: IAsyncEnumerator<T>
private class SingleEnumerable: IAsyncEnumerable<Page<T>>, IAsyncEnumerator<Page<T>>
{
private readonly IEnumerator<T> _enumerator;

public Enumerator(IEnumerator<T> enumerator)
public SingleEnumerable(Page<T> value)
{
_enumerator = enumerator;
Current = value;
}

public ValueTask DisposeAsync() => default;

public ValueTask<bool> MoveNextAsync() => new ValueTask<bool>(_enumerator.MoveNext());
public ValueTask<bool> MoveNextAsync() => new ValueTask<bool>(false);

public Page<T> Current { get; }

public T Current => _enumerator.Current;
public IAsyncEnumerator<Page<T>> GetAsyncEnumerator(CancellationToken cancellationToken = new CancellationToken())
{
return this;
}
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion sdk/core/Microsoft.Extensions.Azure/samples/Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public Startup(IConfiguration configuration)

// This method gets called by the runtime. Use this method to add services to the container.
// For more information on how to configure your application, visit https://go.microsoft.com/fwlink/?LinkID=398940
public void ConfigureServices(IServiceCollection services)
public IServiceProvider ConfigureServices(IServiceCollection services)
pakrym marked this conversation as resolved.
Show resolved Hide resolved
{
// Registering policy to use in ConfigureDefaults later
services.AddSingleton<DependencyInjectionEnabledPolicy>();
Expand Down Expand Up @@ -53,6 +53,8 @@ public void ConfigureServices(IServiceCollection services)
builder.AddBlobServiceClient(Configuration.GetSection("Storage"))
.WithVersion(BlobClientOptions.ServiceVersion.V2018_11_09);
});

return services.BuildServiceProvider();
}

// This method gets called by the runtime. Use this method to configure the HTTP request pipeline.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

<ItemGroup>
<Compile Include="$(AzureCoreSharedSources)ArrayBufferWriter.cs" />
<Compile Include="$(AzureCoreSharedSources)PageResponse.cs" />
<Compile Include="$(AzureCoreSharedSources)PageResponseEnumerator.cs" />
</ItemGroup>

Expand Down
6 changes: 3 additions & 3 deletions sdk/keyvault/Azure.Security.KeyVault.Keys/src/KeyClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ public virtual IEnumerable<Response<KeyBase>> GetKeys(CancellationToken cancella
/// permission.
/// </remarks>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
public virtual IAsyncEnumerable<Response<KeyBase>> GetKeysAsync(CancellationToken cancellationToken = default)
public virtual AsyncCollection<KeyBase> GetKeysAsync(CancellationToken cancellationToken = default)
pakrym marked this conversation as resolved.
Show resolved Hide resolved
{
Uri firstPageUri = CreateFirstPageUri(KeysPath);

Expand Down Expand Up @@ -438,7 +438,7 @@ public virtual IEnumerable<Response<KeyBase>> GetKeyVersions(string name, Cancel
/// </remarks>
/// <param name="name">The name of the key.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
public virtual IAsyncEnumerable<Response<KeyBase>> GetKeyVersionsAsync(string name, CancellationToken cancellationToken = default)
public virtual AsyncCollection<KeyBase> GetKeyVersionsAsync(string name, CancellationToken cancellationToken = default)
{
if (string.IsNullOrEmpty(name)) throw new ArgumentException($"{nameof(name)} can't be empty or null");

Expand Down Expand Up @@ -600,7 +600,7 @@ public virtual IEnumerable<Response<DeletedKey>> GetDeletedKeys(CancellationToke
/// vault. This operation requires the keys/list permission.
/// </remarks>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> controlling the request lifetime.</param>
public virtual IAsyncEnumerable<Response<DeletedKey>> GetDeletedKeysAsync(CancellationToken cancellationToken = default)
public virtual AsyncCollection<DeletedKey> GetDeletedKeysAsync(CancellationToken cancellationToken = default)
{
Uri firstPageUri = CreateFirstPageUri(DeletedKeysPath);

Expand Down
Loading