Skip to content

Commit

Permalink
Update distributed tests to use AAD auth (#9207)
Browse files Browse the repository at this point in the history
* Update distributed tests to use AAD auth
  • Loading branch information
benjaminpetit authored Nov 2, 2024
1 parent c78ee8c commit 86f0ca5
Show file tree
Hide file tree
Showing 11 changed files with 116 additions and 94 deletions.
8 changes: 4 additions & 4 deletions distributed-tests.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
variables:
clusterId: '{{ "now" | date: "%s" }}'
serviceId: '{{ "now" | date: "%s" }}'
secretSource: KeyVault
framework: net8.0

jobs:
Expand All @@ -11,7 +10,7 @@ jobs:
executable: DistributedTests.Server.exe
readyStateText: Orleans Silo started.
framework: net8.0
arguments: "{{configurator}} --clusterId {{clusterId}} --serviceId {{serviceId}} --secretSource {{secretSource}} {{configuratorOptions}}"
arguments: "{{configurator}} --clusterId {{clusterId}} --serviceId {{serviceId}} --azureQueueUri {{azureQueueUri}} --azureTableUri {{azureTableUri}} {{configuratorOptions}}"
onConfigure:
- if (job.endpoints.Count > 0) {
job.endpoints.RemoveRange(job.variables.instances, job.endpoints.Count - job.variables.instances);
Expand All @@ -22,7 +21,7 @@ jobs:
executable: DistributedTests.Client.exe
waitForExit: true
framework: net8.0
arguments: "{{command}} --clusterId {{clusterId}} --serviceId {{serviceId}} --secretSource {{secretSource}} {{commandOptions}}"
arguments: "{{command}} --clusterId {{clusterId}} --serviceId {{serviceId}} --azureQueueUri {{azureQueueUri}} --azureTableUri {{azureTableUri}} {{commandOptions}}"
onConfigure:
- if (job.endpoints.Count > 0) {
job.endpoints.Reverse();
Expand Down Expand Up @@ -173,7 +172,8 @@ results:
profiles:
local:
variables:
secretSource: File
azureQueueUri: "http://127.0.0.1:10001"
azureTableUri: "http://127.0.0.1:10002"
jobs:
server:
endpoints:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ private class Parameters
{
public string ServiceId { get; set; }
public string ClusterId { get; set; }
public SecretConfiguration.SecretSource SecretSource { get; set; }
public Uri AzureTableUri { get; set; }
public Uri AzureQueueUri { get; set; }
public int Wait { get; set; }
public int ServersPerRound { get; set; }
public int Rounds { get; set; }
Expand All @@ -27,7 +28,8 @@ public ChaosAgentCommand(ILogger logger)
{
AddOption(OptionHelper.CreateOption<string>("--serviceId", isRequired: true));
AddOption(OptionHelper.CreateOption<string>("--clusterId", isRequired: true));
AddOption(OptionHelper.CreateOption("--secretSource", defaultValue: SecretConfiguration.SecretSource.File));
AddOption(OptionHelper.CreateOption<Uri>("--azureTableUri", isRequired: true));
AddOption(OptionHelper.CreateOption<Uri>("--azureQueueUri", isRequired: true));
AddOption(OptionHelper.CreateOption<int>("--wait", defaultValue: 30));
AddOption(OptionHelper.CreateOption<int>("--serversPerRound", defaultValue: 1));
AddOption(OptionHelper.CreateOption<int>("--rounds", defaultValue: 5));
Expand All @@ -41,8 +43,7 @@ public ChaosAgentCommand(ILogger logger)

private async Task RunAsync(Parameters parameters)
{
var secrets = SecretConfiguration.Load(parameters.SecretSource);
var channel = await Channels.CreateSendChannel(parameters.ClusterId, secrets);
var channel = await Channels.CreateSendChannel(parameters.ClusterId, parameters.AzureQueueUri);

_logger.LogInformation("Waiting {WaitSeconds} seconds before starting...", parameters.Wait);
await Task.Delay(TimeSpan.FromSeconds(parameters.Wait));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.CommandLine;
using System.CommandLine.Invocation;
using DistributedTests.Common;
using DistributedTests.GrainInterfaces;
using Microsoft.Crank.EventSources;
using Microsoft.Extensions.DependencyInjection;
Expand All @@ -17,7 +18,8 @@ private class Parameters
{
public string ServiceId { get; set; }
public string ClusterId { get; set; }
public SecretConfiguration.SecretSource SecretSource { get; set; }
public Uri AzureTableUri { get; set; }
public Uri AzureQueueUri { get; set; }
public string CounterKey { get; set; }
public List<string> Counters { get; set; }
}
Expand All @@ -27,7 +29,8 @@ public CounterCaptureCommand(ILogger logger)
{
AddOption(OptionHelper.CreateOption<string>("--serviceId", isRequired: true));
AddOption(OptionHelper.CreateOption<string>("--clusterId", isRequired: true));
AddOption(OptionHelper.CreateOption("--secretSource", defaultValue: SecretConfiguration.SecretSource.File));
AddOption(OptionHelper.CreateOption<Uri>("--azureTableUri", isRequired: true));
AddOption(OptionHelper.CreateOption<Uri>("--azureQueueUri", isRequired: true));
AddOption(OptionHelper.CreateOption("--counterKey", defaultValue: StreamingConstants.DefaultCounterGrain));
AddArgument(new Argument<List<string>>("Counters") { Arity = ArgumentArity.OneOrMore });

Expand All @@ -38,12 +41,11 @@ public CounterCaptureCommand(ILogger logger)
private async Task RunAsync(Parameters parameters)
{
_logger.LogInformation("Connecting to cluster...");
var secrets = SecretConfiguration.Load(parameters.SecretSource);
var hostBuilder = new HostBuilder()
.UseOrleansClient((ctx, builder) => {
builder
.Configure<ClusterOptions>(options => { options.ClusterId = parameters.ClusterId; options.ServiceId = parameters.ServiceId; })
.UseAzureStorageClustering(options => options.TableServiceClient = new(secrets.ClusteringConnectionString));
.UseAzureStorageClustering(options => options.TableServiceClient = parameters.AzureTableUri.CreateTableServiceClient());
});
using var host = hostBuilder.Build();
await host.StartAsync();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ public ScenarioCommand(ILoadGeneratorScenario<T> scenario, ILoggerFactory logger
AddOption(OptionHelper.CreateOption<string>("--serviceId", isRequired: true));
AddOption(OptionHelper.CreateOption<string>("--clusterId", isRequired: true));
AddOption(OptionHelper.CreateOption<int>("--connectionsPerEndpoint", defaultValue: 1, validator: OptionHelper.OnlyStrictlyPositive));
AddOption(OptionHelper.CreateOption("--secretSource", defaultValue: SecretConfiguration.SecretSource.File));
AddOption(OptionHelper.CreateOption<Uri>("--azureQueueUri", isRequired: true));
AddOption(OptionHelper.CreateOption<Uri>("--azureTableUri", isRequired: true));

// LoadGeneratorParameters
AddOption(OptionHelper.CreateOption<int>("--numWorkers", defaultValue: 250, validator: OptionHelper.OnlyStrictlyPositive));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Azure.Identity;
using DistributedTests.Common;
using Microsoft.Crank.EventSources;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
Expand All @@ -11,7 +13,8 @@ public class ClientParameters
public string ServiceId { get; set; }
public string ClusterId { get; set; }
public int ConnectionsPerEndpoint { get; set; }
public SecretConfiguration.SecretSource SecretSource { get;set; }
public Uri AzureTableUri { get; set; }
public Uri AzureQueueUri { get; set; }
}

public class LoadGeneratorParameters
Expand All @@ -35,16 +38,17 @@ public LoadGeneratorScenarioRunner(ILoadGeneratorScenario<T> scenario, ILoggerFa

public async Task Run(ClientParameters clientParams, LoadGeneratorParameters loadParams)
{
Console.WriteLine($"AzureTableUri: {clientParams.AzureTableUri}");

// Register the measurements. n0 -> format as natural number
BenchmarksEventSource.Register("requests", Operations.Sum, Operations.Sum, "Requests", "Number of requests completed", "n0");
BenchmarksEventSource.Register("failures", Operations.Sum, Operations.Sum, "Failures", "Number of failures", "n0");
BenchmarksEventSource.Register("rps", Operations.Sum, Operations.Median, "Median RPS", "Rate per second", "n0");

var secrets = SecretConfiguration.Load(clientParams.SecretSource);
var hostBuilder = new HostBuilder().UseOrleansClient((ctx, builder) =>
builder.Configure<ClusterOptions>(options => { options.ClusterId = clientParams.ClusterId; options.ServiceId = clientParams.ServiceId; })
.Configure<ConnectionOptions>(options => clientParams.ConnectionsPerEndpoint = 2)
.UseAzureStorageClustering(options => options.TableServiceClient = new(secrets.ClusteringConnectionString)));
.UseAzureStorageClustering(options => options.TableServiceClient = clientParams.AzureTableUri.CreateTableServiceClient()));
using var host = hostBuilder.Build();

_logger.LogInformation("Connecting to cluster...");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Logging.Console" />
<PackageReference Include="Azure.Storage.Queues" />
<PackageReference Include="Azure.Data.Tables" />
<PackageReference Include="Azure.Identity" />
<PackageReference Include="Azure.Security.KeyVault.Secrets" />
<PackageReference Include="Microsoft.Extensions.Configuration" />
<PackageReference Include="Microsoft.Extensions.Configuration.AzureKeyVault" />
<PackageReference Include="Microsoft.Extensions.Configuration.Json" />
<PackageReference Include="System.CommandLine" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Diagnostics;
using System.Text.Json;
using Azure.Identity;
using Azure.Storage.Queues;

namespace DistributedTests.Common.MessageChannel
Expand Down Expand Up @@ -69,21 +71,27 @@ public static class Channels
private static readonly string CLIENT_TO_SERVER_QUEUE = "servers-{0}";
private static readonly string SILO_TO_CLIENT_QUEUE = "client-{0}";

public static async Task<ISendChannel> CreateSendChannel(string clusterId, SecretConfiguration configuration)
public static Task<ISendChannel> CreateSendChannel(string clusterId, Uri azureQueueUri)
=> CreateSendChannel(clusterId, azureQueueUri.CreateQueueServiceClient());

public static async Task<ISendChannel> CreateSendChannel(string clusterId, QueueServiceClient queueServiceClient)
{
var writeQueue = new QueueClient(configuration.ClusteringConnectionString, string.Format(CLIENT_TO_SERVER_QUEUE, clusterId));
var readQueue = new QueueClient(configuration.ClusteringConnectionString, string.Format(SILO_TO_CLIENT_QUEUE, clusterId));
var writeQueue = queueServiceClient.GetQueueClient(string.Format(CLIENT_TO_SERVER_QUEUE, clusterId));
var readQueue = queueServiceClient.GetQueueClient(string.Format(SILO_TO_CLIENT_QUEUE, clusterId));

await writeQueue.CreateIfNotExistsAsync();
await readQueue.CreateIfNotExistsAsync();

return new SendChannel(writeQueue, readQueue);
}

public static async Task<IReceiveChannel> CreateReceiveChannel(string serverName, string clusterId, SecretConfiguration configuration)
public static Task<IReceiveChannel> CreateReceiveChannel(string serverName, string clusterId, Uri azureQueueUri)
=> CreateReceiveChannel(serverName, clusterId, azureQueueUri.CreateQueueServiceClient());

public static async Task<IReceiveChannel> CreateReceiveChannel(string serverName, string clusterId, QueueServiceClient queueServiceClient)
{
var readQueue = new QueueClient(configuration.ClusteringConnectionString, string.Format(CLIENT_TO_SERVER_QUEUE, clusterId));
var writeQueue = new QueueClient(configuration.ClusteringConnectionString, string.Format(SILO_TO_CLIENT_QUEUE, clusterId));
var writeQueue = queueServiceClient.GetQueueClient(string.Format(SILO_TO_CLIENT_QUEUE, clusterId));
var readQueue = queueServiceClient.GetQueueClient(string.Format(CLIENT_TO_SERVER_QUEUE, clusterId));

await writeQueue.CreateIfNotExistsAsync();
await readQueue.CreateIfNotExistsAsync();
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Data.Tables;
using Azure.Identity;
using Azure.Storage.Queues;

namespace DistributedTests.Common;

public static class TokenCredentialHelper
{
private static string EmulatorConnectionString = "UseDevelopmentStorage=true";

public static TableServiceClient CreateTableServiceClient(this Uri azureTableUri)
{
if (azureTableUri.IsLoopback)
{
// Assume it's the emulator/azurite
return new TableServiceClient(EmulatorConnectionString);
}
return new TableServiceClient(azureTableUri, GetTokenCredential());
}

public static QueueServiceClient CreateQueueServiceClient(this Uri azureQueueUri)
{
if (azureQueueUri.IsLoopback)
{
// Assume it's the emulator/azurite
return new QueueServiceClient(EmulatorConnectionString);
}
return new QueueServiceClient(azureQueueUri, GetTokenCredential());
}

public static TokenCredential GetTokenCredential()
{
var tenantId = Environment.GetEnvironmentVariable("TENANT_ID");
var clientId = Environment.GetEnvironmentVariable("CLIENT_ID");
if (tenantId != null && clientId != null)
{
// Uses Federated Id Creds, from here:
// https://review.learn.microsoft.com/en-us/identity/microsoft-identity-platform/federated-identity-credentials?branch=main&tabs=dotnet#azure-sdk-for-net
return new ClientAssertionCredential(
tenantId, // Tenant ID for destination resource
clientId, // Client ID of the app we're federating to
() => GetManagedIdentityToken(null, "api://AzureADTokenExchange")) // null here for default MSI
;
}
else
{
return new DefaultAzureCredential();
}
}

/// <summary>
/// Gets a token for the user-assigned Managed Identity.
/// </summary>
/// <param name="msiClientId">Client ID for the Managed Identity.</param>
/// <param name="audience">Target audience. For public clouds should be api://AzureADTokenExchange.</param>
/// <returns>If successful, returns an access token.</returns>
public static string GetManagedIdentityToken(string msiClientId, string audience)
{
var miCredential = new ManagedIdentityCredential(msiClientId);
return miCredential.GetToken(new TokenRequestContext(new[] { $"{audience}/.default" })).Token;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ public ServerCommand(ISiloConfigurator<T> siloConfigurator)
AddOption(OptionHelper.CreateOption<string>("--clusterId", isRequired: true));
AddOption(OptionHelper.CreateOption("--siloPort", defaultValue: 11111));
AddOption(OptionHelper.CreateOption("--gatewayPort", defaultValue: 30000));
AddOption(OptionHelper.CreateOption("--secretSource", defaultValue: SecretConfiguration.SecretSource.File));
AddOption(OptionHelper.CreateOption<Uri>("--azureQueueUri", isRequired: true));
AddOption(OptionHelper.CreateOption<Uri>("--azureTableUri", isRequired: true));
AddOption(OptionHelper.CreateOption("--activationRepartitioning", defaultValue: false));

foreach (var opt in siloConfigurator.Options)
Expand Down
Loading

0 comments on commit 86f0ca5

Please sign in to comment.