From 2bc2990146eeb6651133ff2e0c486a7b1c830d68 Mon Sep 17 00:00:00 2001 From: Pavel Krymets Date: Thu, 18 Jul 2019 22:25:11 -0700 Subject: [PATCH] Add identity client distributed tracing instrumentation (#6972) --- .../Azure.Identity/src/AadIdentityClient.cs | 49 +++- .../src/ManagedIdentityClient.cs | 38 ++- .../src/Properties/AssemblyInfo.cs | 1 + .../tests/ManagedIdentityClientTests.cs | 277 ++++++++++++++++++ .../tests/ManagedIdentityCredentialTests.cs | 263 +---------------- .../tests/Mock/AadIdentityClientTests.cs | 179 +++++++++++ .../MockClientCertificateCredentialTests.cs | 119 +------- 7 files changed, 543 insertions(+), 383 deletions(-) create mode 100644 sdk/identity/Azure.Identity/tests/ManagedIdentityClientTests.cs create mode 100644 sdk/identity/Azure.Identity/tests/Mock/AadIdentityClientTests.cs diff --git a/sdk/identity/Azure.Identity/src/AadIdentityClient.cs b/sdk/identity/Azure.Identity/src/AadIdentityClient.cs index 6a7edbcaed1bb..a33396e726444 100644 --- a/sdk/identity/Azure.Identity/src/AadIdentityClient.cs +++ b/sdk/identity/Azure.Identity/src/AadIdentityClient.cs @@ -30,6 +30,10 @@ internal class AadIdentityClient private const string AuthenticationRequestFailedError = "The request to the identity service failed. See inner exception for details."; + protected AadIdentityClient() + { + } + public AadIdentityClient(IdentityClientOptions options = null) { _options = options ?? new IdentityClientOptions(); @@ -42,8 +46,12 @@ public AadIdentityClient(IdentityClientOptions options = null) public virtual async Task AuthenticateAsync(string tenantId, string clientId, string clientSecret, string[] scopes, CancellationToken cancellationToken = default) { - using (Request request = CreateClientSecretAuthRequest(tenantId, clientId, clientSecret, scopes)) + using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope("Azure.Identity.AadIdentityClient.Authenticate"); + scope.Start(); + + try { + using Request request = CreateClientSecretAuthRequest(tenantId, clientId, clientSecret, scopes); try { return await SendAuthRequestAsync(request, cancellationToken).ConfigureAwait(false); @@ -53,12 +61,21 @@ public virtual async Task AuthenticateAsync(string tenantId, string throw new AuthenticationFailedException(AuthenticationRequestFailedError, ex); } } + catch (Exception e) + { + scope.Failed(e); + throw; + } } public virtual AccessToken Authenticate(string tenantId, string clientId, string clientSecret, string[] scopes, CancellationToken cancellationToken = default) { - using (Request request = CreateClientSecretAuthRequest(tenantId, clientId, clientSecret, scopes)) + using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope("Azure.Identity.AadIdentityClient.Authenticate"); + scope.Start(); + + try { + using Request request = CreateClientSecretAuthRequest(tenantId, clientId, clientSecret, scopes); try { return SendAuthRequest(request, cancellationToken); @@ -68,12 +85,21 @@ public virtual AccessToken Authenticate(string tenantId, string clientId, string throw new AuthenticationFailedException(AuthenticationRequestFailedError, ex); } } + catch (Exception e) + { + scope.Failed(e); + throw; + } } public virtual async Task AuthenticateAsync(string tenantId, string clientId, X509Certificate2 clientCertificate, string[] scopes, CancellationToken cancellationToken = default) { - using (Request request = CreateClientCertificateAuthRequest(tenantId, clientId, clientCertificate, scopes)) + using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope("Azure.Identity.AadIdentityClient.Authenticate"); + scope.Start(); + + try { + using Request request = CreateClientCertificateAuthRequest(tenantId, clientId, clientCertificate, scopes); try { return await SendAuthRequestAsync(request, cancellationToken).ConfigureAwait(false); @@ -83,12 +109,21 @@ public virtual async Task AuthenticateAsync(string tenantId, string throw new AuthenticationFailedException(AuthenticationRequestFailedError, ex); } } + catch (Exception e) + { + scope.Failed(e); + throw; + } } public virtual AccessToken Authenticate(string tenantId, string clientId, X509Certificate2 clientCertificate, string[] scopes, CancellationToken cancellationToken = default) { - using (Request request = CreateClientCertificateAuthRequest(tenantId, clientId, clientCertificate, scopes)) + using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope("Azure.Identity.AadIdentityClient.Authenticate"); + scope.Start(); + + try { + using Request request = CreateClientCertificateAuthRequest(tenantId, clientId, clientCertificate, scopes); try { return SendAuthRequest(request, cancellationToken); @@ -98,7 +133,13 @@ public virtual AccessToken Authenticate(string tenantId, string clientId, X509Ce throw new AuthenticationFailedException(AuthenticationRequestFailedError, ex); } } + catch (Exception e) + { + scope.Failed(e); + throw; + } } + private async Task SendAuthRequestAsync(Request request, CancellationToken cancellationToken) { var response = await _pipeline.SendRequestAsync(request, cancellationToken).ConfigureAwait(false); diff --git a/sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs b/sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs index a76fbc7a26f90..6a3bd6e336cbc 100644 --- a/sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs +++ b/sdk/identity/Azure.Identity/src/ManagedIdentityClient.cs @@ -38,6 +38,10 @@ internal class ManagedIdentityClient private readonly IdentityClientOptions _options; private readonly HttpPipeline _pipeline; + protected ManagedIdentityClient() + { + } + public ManagedIdentityClient(IdentityClientOptions options = null) { _options = options ?? new IdentityClientOptions(); @@ -66,13 +70,24 @@ public virtual async Task AuthenticateAsync(string[] scopes, string return default; } + using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope("Azure.Identity.ManagedIdentityClient.Authenticate"); + scope.Start(); + try { - return await SendAuthRequestAsync(msiType, scopes, clientId, cancellationToken).ConfigureAwait(false); + try + { + return await SendAuthRequestAsync(msiType, scopes, clientId, cancellationToken).ConfigureAwait(false); + } + catch (RequestFailedException ex) + { + throw new AuthenticationFailedException(AuthenticationRequestFailedError, ex); + } } - catch (RequestFailedException ex) + catch (Exception e) { - throw new AuthenticationFailedException(AuthenticationRequestFailedError, ex); + scope.Failed(e); + throw; } } @@ -86,13 +101,24 @@ public virtual AccessToken Authenticate(string[] scopes, string clientId = null, return default; } + using DiagnosticScope scope = _pipeline.Diagnostics.CreateScope("Azure.Identity.ManagedIdentityClient.Authenticate"); + scope.Start(); + try { - return SendAuthRequest(msiType, scopes, clientId, cancellationToken); + try + { + return SendAuthRequest(msiType, scopes, clientId, cancellationToken); + } + catch(RequestFailedException ex) + { + throw new AuthenticationFailedException(AuthenticationRequestFailedError, ex); + } } - catch(RequestFailedException ex) + catch (Exception e) { - throw new AuthenticationFailedException(AuthenticationRequestFailedError, ex); + scope.Failed(e); + throw; } } diff --git a/sdk/identity/Azure.Identity/src/Properties/AssemblyInfo.cs b/sdk/identity/Azure.Identity/src/Properties/AssemblyInfo.cs index c9aee09068646..0d66b411c4431 100644 --- a/sdk/identity/Azure.Identity/src/Properties/AssemblyInfo.cs +++ b/sdk/identity/Azure.Identity/src/Properties/AssemblyInfo.cs @@ -1,3 +1,4 @@ using System.Runtime.CompilerServices; [assembly: InternalsVisibleTo("Azure.Identity.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100d15ddcb29688295338af4b7686603fe614abd555e09efba8fb88ee09e1f7b1ccaeed2e8f823fa9eef3fdd60217fc012ea67d2479751a0b8c087a4185541b851bd8b16f8d91b840e51b1cb0ba6fe647997e57429265e85ef62d565db50a69ae1647d54d7bd855e4db3d8a91510e5bcbd0edfbbecaa20a7bd9ae74593daa7b11b4")] +[assembly: InternalsVisibleTo("DynamicProxyGenAssembly2, PublicKey=0024000004800000940000000602000000240000525341310004000001000100c547cac37abd99c8db225ef2f6c8a3602f3b3606cc9891605d02baa56104f4cfc0734aa39b93bf7852f7d9266654753cc297e7d2edfe0bac1cdcf9f717241550e0a7b191195b7667bb4f64bcb8e2121380fd1d9d46ad2d92d2d15605093924cceaf74c4861eff62abf69b9291ed0a340e113be11e6a7d3113e92484cf7045cc7")] \ No newline at end of file diff --git a/sdk/identity/Azure.Identity/tests/ManagedIdentityClientTests.cs b/sdk/identity/Azure.Identity/tests/ManagedIdentityClientTests.cs new file mode 100644 index 0000000000000..02ba0db4fa2b8 --- /dev/null +++ b/sdk/identity/Azure.Identity/tests/ManagedIdentityClientTests.cs @@ -0,0 +1,277 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using Azure.Core; +using Azure.Core.Testing; +using Azure.Identity.Tests.Mock; +using NUnit.Framework; + +namespace Azure.Identity.Tests +{ + public class ManagedIdentityClientTests : ClientTestBase + { + [SetUp] + public void ResetManagedIdenityClient() + { + typeof(ManagedIdentityClient).GetField("s_msiType", BindingFlags.NonPublic | BindingFlags.Static).SetValue(null, 0); + typeof(ManagedIdentityClient).GetField("s_endpoint", BindingFlags.NonPublic | BindingFlags.Static).SetValue(null, null); + } + + public ManagedIdentityClientTests(bool isAsync) : base(isAsync) + { + } + + [NonParallelizable] + [Test] + public async Task VerifyImdsRequestMockAsync() + { + using (new TestEnvVar("MSI_ENDPOINT", null)) + using (new TestEnvVar("MSI_SECRET", null)) + { + var response = new MockResponse(200); + + var expectedToken = "mock-msi-access-token"; + + response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_on\": \"3600\" }}"); + + var mockTransport = new MockTransport(response, response); + + var options = new IdentityClientOptions() { Transport = mockTransport }; + + var client = InstrumentClient(new ManagedIdentityClient(options)); + + AccessToken actualToken = await client.AuthenticateAsync(MockScopes.Default); + + Assert.AreEqual(expectedToken, actualToken.Token); + + MockRequest request = mockTransport.Requests[mockTransport.Requests.Count - 1]; + + string query = request.UriBuilder.Query; + + Assert.IsTrue(query.Contains("api-version=2018-02-01")); + + Assert.IsTrue(query.Contains($"resource={Uri.EscapeDataString(ScopeUtilities.ScopesToResource(MockScopes.Default))}")); + + Assert.IsTrue(request.Headers.TryGetValue("Metadata", out string metadataValue)); + + Assert.AreEqual("true", metadataValue); + } + } + + [NonParallelizable] + [Test] + public async Task VerifyImdsRequestWithClientIdMockAsync() + { + using (new TestEnvVar("MSI_ENDPOINT", null)) + using (new TestEnvVar("MSI_SECRET", null)) + { + var response = new MockResponse(200); + + var expectedToken = "mock-msi-access-token"; + + response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_on\": \"3600\" }}"); + + var mockTransport = new MockTransport(response, response); + + var options = new IdentityClientOptions() { Transport = mockTransport }; + + var client = InstrumentClient(new ManagedIdentityClient(options)); + + AccessToken actualToken = await client.AuthenticateAsync(MockScopes.Default, clientId: "mock-client-id"); + + Assert.AreEqual(expectedToken, actualToken.Token); + + MockRequest request = mockTransport.Requests[mockTransport.Requests.Count - 1]; + + string query = request.UriBuilder.Query; + + Assert.IsTrue(query.Contains("api-version=2018-02-01")); + + Assert.IsTrue(query.Contains($"resource={Uri.EscapeDataString(ScopeUtilities.ScopesToResource(MockScopes.Default))}")); + + Assert.IsTrue(query.Contains($"client_id=mock-client-id")); + + Assert.IsTrue(request.Headers.TryGetValue("Metadata", out string metadataValue)); + + Assert.AreEqual("true", metadataValue); + } + } + + [NonParallelizable] + [Test] + public async Task VerifyAppServiceMsiRequestMockAsync() + { + using (new TestEnvVar("MSI_ENDPOINT", "https://mock.msi.endpoint/")) + using (new TestEnvVar("MSI_SECRET", "mock-msi-secret")) + { + var response = new MockResponse(200); + + var expectedToken = "mock-msi-access-token"; + + response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_on\": \"{DateTimeOffset.UtcNow.ToString()}\" }}"); + + var mockTransport = new MockTransport(response); + + var options = new IdentityClientOptions() { Transport = mockTransport }; + + var client = InstrumentClient(new ManagedIdentityClient(options)); + + AccessToken actualToken = await client.AuthenticateAsync(MockScopes.Default); + + Assert.AreEqual(expectedToken, actualToken.Token); + + MockRequest request = mockTransport.Requests[0]; + + Assert.IsTrue(request.UriBuilder.ToString().StartsWith("https://mock.msi.endpoint/")); + + string query = request.UriBuilder.Query; + + Assert.IsTrue(query.Contains("api-version=2017-09-01")); + + Assert.IsTrue(query.Contains($"resource={Uri.EscapeDataString(ScopeUtilities.ScopesToResource(MockScopes.Default))}")); + + Assert.IsTrue(request.Headers.TryGetValue("secret", out string actSecretValue)); + + Assert.AreEqual("mock-msi-secret", actSecretValue); + } + } + + [NonParallelizable] + [Test] + public async Task VerifyAppServiceMsiRequestWithClientIdMockAsync() + { + using (new TestEnvVar("MSI_ENDPOINT", "https://mock.msi.endpoint/")) + using (new TestEnvVar("MSI_SECRET", "mock-msi-secret")) + { + var response = new MockResponse(200); + + var expectedToken = "mock-msi-access-token"; + + response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_on\": \"{DateTimeOffset.UtcNow.ToString()}\" }}"); + + var mockTransport = new MockTransport(response); + + var options = new IdentityClientOptions() { Transport = mockTransport }; + + var client = InstrumentClient(new ManagedIdentityClient(options)); + + AccessToken actualToken = await client.AuthenticateAsync(MockScopes.Default, "mock-client-id"); + + Assert.AreEqual(expectedToken, actualToken.Token); + + MockRequest request = mockTransport.Requests[0]; + + Assert.IsTrue(request.UriBuilder.ToString().StartsWith("https://mock.msi.endpoint/")); + + string query = request.UriBuilder.Query; + + Assert.IsTrue(query.Contains("api-version=2017-09-01")); + + Assert.IsTrue(query.Contains($"resource={Uri.EscapeDataString(ScopeUtilities.ScopesToResource(MockScopes.Default))}")); + + Assert.IsTrue(query.Contains($"client_id=mock-client-id")); + + Assert.IsTrue(request.Headers.TryGetValue("secret", out string actSecretValue)); + + Assert.AreEqual("mock-msi-secret", actSecretValue); + } + } + + [NonParallelizable] + [Test] + public async Task VerifyCloudShellMsiRequestMockAsync() + { + using (new TestEnvVar("MSI_ENDPOINT", "https://mock.msi.endpoint/")) + using (new TestEnvVar("MSI_SECRET", null)) + { + var response = new MockResponse(200); + + var expectedToken = "mock-msi-access-token"; + + response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_on\": {(DateTimeOffset.UtcNow + TimeSpan.FromSeconds(3600)).ToUnixTimeSeconds()} }}"); + + var mockTransport = new MockTransport(response); + + var options = new IdentityClientOptions() { Transport = mockTransport }; + + var client = InstrumentClient(new ManagedIdentityClient(options)); + + AccessToken actualToken = await client.AuthenticateAsync(MockScopes.Default); + + Assert.AreEqual(expectedToken, actualToken.Token); + + MockRequest request = mockTransport.Requests[0]; + + Assert.IsTrue(request.UriBuilder.ToString().StartsWith("https://mock.msi.endpoint/")); + + Assert.IsTrue(request.Content.TryComputeLength(out long contentLen)); + + var content = new byte[contentLen]; + + MemoryStream contentBuff = new MemoryStream(content); + + request.Content.WriteTo(contentBuff, default); + + string body = Encoding.UTF8.GetString(content); + + Assert.IsTrue(body.Contains($"resource={Uri.EscapeDataString(ScopeUtilities.ScopesToResource(MockScopes.Default))}")); + + Assert.IsTrue(request.Headers.TryGetValue("Metadata", out string actMetadata)); + + Assert.AreEqual("true", actMetadata); + } + } + + [NonParallelizable] + [Test] + public async Task VerifyCloudShellMsiRequestWithClientIdMockAsync() + { + using (new TestEnvVar("MSI_ENDPOINT", "https://mock.msi.endpoint/")) + using (new TestEnvVar("MSI_SECRET", null)) + { + var response = new MockResponse(200); + + var expectedToken = "mock-msi-access-token"; + + response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_on\": {(DateTimeOffset.UtcNow + TimeSpan.FromSeconds(3600)).ToUnixTimeSeconds()} }}"); + + var mockTransport = new MockTransport(response); + + var options = new IdentityClientOptions() { Transport = mockTransport }; + + var client = InstrumentClient(new ManagedIdentityClient(options)); + + AccessToken actualToken = await client.AuthenticateAsync(MockScopes.Default, "mock-client-id"); + + Assert.AreEqual(expectedToken, actualToken.Token); + + MockRequest request = mockTransport.Requests[0]; + + Assert.IsTrue(request.UriBuilder.ToString().StartsWith("https://mock.msi.endpoint/")); + + Assert.IsTrue(request.Content.TryComputeLength(out long contentLen)); + + var content = new byte[contentLen]; + + MemoryStream contentBuff = new MemoryStream(content); + + request.Content.WriteTo(contentBuff, default); + + string body = Encoding.UTF8.GetString(content); + + Assert.IsTrue(body.Contains($"resource={Uri.EscapeDataString(ScopeUtilities.ScopesToResource(MockScopes.Default))}")); + + Assert.IsTrue(body.Contains($"client_id=mock-client-id")); + + Assert.IsTrue(request.Headers.TryGetValue("Metadata", out string actMetadata)); + + Assert.AreEqual("true", actMetadata); + } + } + } +} \ No newline at end of file diff --git a/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs b/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs index 72ef90825d357..e8396946fa5a2 100644 --- a/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs +++ b/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs @@ -1,18 +1,19 @@ -using Azure.Core; +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + using Azure.Core.Testing; -using Azure.Identity.Tests.Mock; using NUnit.Framework; -using System; -using System.Collections.Generic; -using System.Text; using System.Threading.Tasks; using System.Reflection; -using System.IO; namespace Azure.Identity.Tests { - public class ManagedIdentityCredentialTests + public class ManagedIdentityCredentialTests: ClientTestBase { + public ManagedIdentityCredentialTests(bool isAsync) : base(isAsync) + { + } + [SetUp] public void ResetManagedIdenityClient() { @@ -30,253 +31,5 @@ public async Task GetSystemTokenLiveAsync() Assert.IsNotNull(token.Token); } - - [NonParallelizable] - [Test] - public async Task VerifyImdsRequestMockAsync() - { - using (new TestEnvVar("MSI_ENDPOINT", null)) - using (new TestEnvVar("MSI_SECRET", null)) - { - var response = new MockResponse(200); - - var expectedToken = "mock-msi-access-token"; - - response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_on\": \"3600\" }}"); - - var mockTransport = new MockTransport(response, response); - - var options = new IdentityClientOptions() { Transport = mockTransport }; - - var credential = new ManagedIdentityCredential(options: options); - - AccessToken actualToken = await credential.GetTokenAsync(MockScopes.Default); - - Assert.AreEqual(expectedToken, actualToken.Token); - - MockRequest request = mockTransport.Requests[mockTransport.Requests.Count - 1]; - - string query = request.UriBuilder.Query; - - Assert.IsTrue(query.Contains("api-version=2018-02-01")); - - Assert.IsTrue(query.Contains($"resource={Uri.EscapeDataString(ScopeUtilities.ScopesToResource(MockScopes.Default))}")); - - Assert.IsTrue(request.Headers.TryGetValue("Metadata", out string metadataValue)); - - Assert.AreEqual("true", metadataValue); - } - } - - [NonParallelizable] - [Test] - public async Task VerifyImdsRequestWithClientIdMockAsync() - { - using (new TestEnvVar("MSI_ENDPOINT", null)) - using (new TestEnvVar("MSI_SECRET", null)) - { - var response = new MockResponse(200); - - var expectedToken = "mock-msi-access-token"; - - response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_on\": \"3600\" }}"); - - var mockTransport = new MockTransport(response, response); - - var options = new IdentityClientOptions() { Transport = mockTransport }; - - var credential = new ManagedIdentityCredential(clientId: "mock-client-id", options: options); - - AccessToken actualToken = await credential.GetTokenAsync(MockScopes.Default); - - Assert.AreEqual(expectedToken, actualToken.Token); - - MockRequest request = mockTransport.Requests[mockTransport.Requests.Count - 1]; - - string query = request.UriBuilder.Query; - - Assert.IsTrue(query.Contains("api-version=2018-02-01")); - - Assert.IsTrue(query.Contains($"resource={Uri.EscapeDataString(ScopeUtilities.ScopesToResource(MockScopes.Default))}")); - - Assert.IsTrue(query.Contains($"client_id=mock-client-id")); - - Assert.IsTrue(request.Headers.TryGetValue("Metadata", out string metadataValue)); - - Assert.AreEqual("true", metadataValue); - } - } - - [NonParallelizable] - [Test] - public async Task VerifyAppServiceMsiRequestMockAsync() - { - using (new TestEnvVar("MSI_ENDPOINT", "https://mock.msi.endpoint/")) - using (new TestEnvVar("MSI_SECRET", "mock-msi-secret")) - { - var response = new MockResponse(200); - - var expectedToken = "mock-msi-access-token"; - - response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_on\": \"{DateTimeOffset.UtcNow.ToString()}\" }}"); - - var mockTransport = new MockTransport(response); - - var options = new IdentityClientOptions() { Transport = mockTransport }; - - var credential = new ManagedIdentityCredential(options: options); - - AccessToken actualToken = await credential.GetTokenAsync(MockScopes.Default); - - Assert.AreEqual(expectedToken, actualToken.Token); - - MockRequest request = mockTransport.Requests[0]; - - Assert.IsTrue(request.UriBuilder.ToString().StartsWith("https://mock.msi.endpoint/")); - - string query = request.UriBuilder.Query; - - Assert.IsTrue(query.Contains("api-version=2017-09-01")); - - Assert.IsTrue(query.Contains($"resource={Uri.EscapeDataString(ScopeUtilities.ScopesToResource(MockScopes.Default))}")); - - Assert.IsTrue(request.Headers.TryGetValue("secret", out string actSecretValue)); - - Assert.AreEqual("mock-msi-secret", actSecretValue); - } - } - - [NonParallelizable] - [Test] - public async Task VerifyAppServiceMsiRequestWithClientIdMockAsync() - { - using (new TestEnvVar("MSI_ENDPOINT", "https://mock.msi.endpoint/")) - using (new TestEnvVar("MSI_SECRET", "mock-msi-secret")) - { - var response = new MockResponse(200); - - var expectedToken = "mock-msi-access-token"; - - response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_on\": \"{DateTimeOffset.UtcNow.ToString()}\" }}"); - - var mockTransport = new MockTransport(response); - - var options = new IdentityClientOptions() { Transport = mockTransport }; - - var credential = new ManagedIdentityCredential(clientId: "mock-client-id", options: options); - - AccessToken actualToken = await credential.GetTokenAsync(MockScopes.Default); - - Assert.AreEqual(expectedToken, actualToken.Token); - - MockRequest request = mockTransport.Requests[0]; - - Assert.IsTrue(request.UriBuilder.ToString().StartsWith("https://mock.msi.endpoint/")); - - string query = request.UriBuilder.Query; - - Assert.IsTrue(query.Contains("api-version=2017-09-01")); - - Assert.IsTrue(query.Contains($"resource={Uri.EscapeDataString(ScopeUtilities.ScopesToResource(MockScopes.Default))}")); - - Assert.IsTrue(query.Contains($"client_id=mock-client-id")); - - Assert.IsTrue(request.Headers.TryGetValue("secret", out string actSecretValue)); - - Assert.AreEqual("mock-msi-secret", actSecretValue); - } - } - - [NonParallelizable] - [Test] - public async Task VerifyCloudShellMsiRequestMockAsync() - { - using (new TestEnvVar("MSI_ENDPOINT", "https://mock.msi.endpoint/")) - using (new TestEnvVar("MSI_SECRET", null)) - { - var response = new MockResponse(200); - - var expectedToken = "mock-msi-access-token"; - - response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_on\": {(DateTimeOffset.UtcNow + TimeSpan.FromSeconds(3600)).ToUnixTimeSeconds()} }}"); - - var mockTransport = new MockTransport(response); - - var options = new IdentityClientOptions() { Transport = mockTransport }; - - var credential = new ManagedIdentityCredential(options: options); - - AccessToken actualToken = await credential.GetTokenAsync(MockScopes.Default); - - Assert.AreEqual(expectedToken, actualToken.Token); - - MockRequest request = mockTransport.Requests[0]; - - Assert.IsTrue(request.UriBuilder.ToString().StartsWith("https://mock.msi.endpoint/")); - - Assert.IsTrue(request.Content.TryComputeLength(out long contentLen)); - - var content = new byte[contentLen]; - - MemoryStream contentBuff = new MemoryStream(content); - - request.Content.WriteTo(contentBuff, default); - - string body = Encoding.UTF8.GetString(content); - - Assert.IsTrue(body.Contains($"resource={Uri.EscapeDataString(ScopeUtilities.ScopesToResource(MockScopes.Default))}")); - - Assert.IsTrue(request.Headers.TryGetValue("Metadata", out string actMetadata)); - - Assert.AreEqual("true", actMetadata); - } - } - - [NonParallelizable] - [Test] - public async Task VerifyCloudShellMsiRequestWithClientIdMockAsync() - { - using (new TestEnvVar("MSI_ENDPOINT", "https://mock.msi.endpoint/")) - using (new TestEnvVar("MSI_SECRET", null)) - { - var response = new MockResponse(200); - - var expectedToken = "mock-msi-access-token"; - - response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_on\": {(DateTimeOffset.UtcNow + TimeSpan.FromSeconds(3600)).ToUnixTimeSeconds()} }}"); - - var mockTransport = new MockTransport(response); - - var options = new IdentityClientOptions() { Transport = mockTransport }; - - var credential = new ManagedIdentityCredential(clientId: "mock-client-id", options: options); - - AccessToken actualToken = await credential.GetTokenAsync(MockScopes.Default); - - Assert.AreEqual(expectedToken, actualToken.Token); - - MockRequest request = mockTransport.Requests[0]; - - Assert.IsTrue(request.UriBuilder.ToString().StartsWith("https://mock.msi.endpoint/")); - - Assert.IsTrue(request.Content.TryComputeLength(out long contentLen)); - - var content = new byte[contentLen]; - - MemoryStream contentBuff = new MemoryStream(content); - - request.Content.WriteTo(contentBuff, default); - - string body = Encoding.UTF8.GetString(content); - - Assert.IsTrue(body.Contains($"resource={Uri.EscapeDataString(ScopeUtilities.ScopesToResource(MockScopes.Default))}")); - - Assert.IsTrue(body.Contains($"client_id=mock-client-id")); - - Assert.IsTrue(request.Headers.TryGetValue("Metadata", out string actMetadata)); - - Assert.AreEqual("true", actMetadata); - } - } } } diff --git a/sdk/identity/Azure.Identity/tests/Mock/AadIdentityClientTests.cs b/sdk/identity/Azure.Identity/tests/Mock/AadIdentityClientTests.cs new file mode 100644 index 0000000000000..ea959ee5a2a0a --- /dev/null +++ b/sdk/identity/Azure.Identity/tests/Mock/AadIdentityClientTests.cs @@ -0,0 +1,179 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using Azure.Core; +using Azure.Core.Testing; +using NUnit.Framework; + +namespace Azure.Identity.Tests.Mock +{ + public class AadIdentityClientTests : ClientTestBase + { + public AadIdentityClientTests(bool isAsync) : base(isAsync) + { + } + + [Test] + public async Task VerifyClientClientSecretRequestAsync() + { + var response = new MockResponse(200); + + var expectedToken = "mock-msi-access-token"; + + response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_in\": 3600 }}"); + + var mockTransport = new MockTransport(response); + + var options = new IdentityClientOptions() { Transport = mockTransport }; + + var expectedTenantId = Guid.NewGuid().ToString(); + + var expectedClientId = Guid.NewGuid().ToString(); + + var expectedClientSecret = "secret"; + + var client = InstrumentClient(new AadIdentityClient(options: options)); + + AccessToken actualToken = await client.AuthenticateAsync(expectedTenantId, expectedClientId, expectedClientSecret, MockScopes.Default); + + Assert.AreEqual(expectedToken, actualToken.Token); + + MockRequest request = mockTransport.SingleRequest; + + Assert.IsTrue(request.Content.TryComputeLength(out long contentLen)); + + var content = new byte[contentLen]; + + await request.Content.WriteToAsync(new MemoryStream(content), default); + + Assert.IsTrue(TryParseFormEncodedBody(content, out Dictionary parsedBody)); + + Assert.IsTrue(parsedBody.TryGetValue("response_type", out string responseType) && responseType == "token"); + + Assert.IsTrue(parsedBody.TryGetValue("grant_type", out string grantType) && grantType == "client_credentials"); + + Assert.IsTrue(parsedBody.TryGetValue("client_id", out string actualClientId) && actualClientId == expectedClientId); + + Assert.IsTrue(parsedBody.TryGetValue("client_secret", out string actualClientSecret) && actualClientSecret == "secret"); + + Assert.IsTrue(parsedBody.TryGetValue("scope", out string actualScope) && actualScope == MockScopes.Default.ToString()); + } + + [Test] + public async Task VerifyClientCertificateRequestAsync() + { + var response = new MockResponse(200); + + var expectedToken = "mock-msi-access-token"; + + response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_in\": 3600 }}"); + + var mockTransport = new MockTransport(response); + + var options = new IdentityClientOptions() { Transport = mockTransport }; + + var expectedTenantId = Guid.NewGuid().ToString(); + + var expectedClientId = Guid.NewGuid().ToString(); + + var mockCert = new X509Certificate2("./Data/cert.pfx", "password"); + + var client = InstrumentClient(new AadIdentityClient(options: options)); + + AccessToken actualToken = await client.AuthenticateAsync(expectedTenantId, expectedClientId, mockCert, MockScopes.Default); + + Assert.AreEqual(expectedToken, actualToken.Token); + + MockRequest request = mockTransport.SingleRequest; + + Assert.IsTrue(request.Content.TryComputeLength(out long contentLen)); + + var content = new byte[contentLen]; + + await request.Content.WriteToAsync(new MemoryStream(content), default); + + Assert.IsTrue(TryParseFormEncodedBody(content, out Dictionary parsedBody)); + + Assert.IsTrue(parsedBody.TryGetValue("response_type", out string responseType) && responseType == "token"); + + Assert.IsTrue(parsedBody.TryGetValue("grant_type", out string grantType) && grantType == "client_credentials"); + + Assert.IsTrue(parsedBody.TryGetValue("client_assertion_type", out string assertionType) && assertionType == "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"); + + Assert.IsTrue(parsedBody.TryGetValue("client_id", out string actualClientId) && actualClientId == expectedClientId); + + Assert.IsTrue(parsedBody.TryGetValue("scope", out string actualScope) && actualScope == MockScopes.Default.ToString()); + + Assert.IsTrue(parsedBody.TryGetValue("client_assertion", out string clientAssertion)); + + // var header + VerifyClientAssertion(clientAssertion, expectedTenantId, expectedClientId, mockCert); + } + + public void VerifyClientAssertion(string clientAssertion, string expectedTenantId, string expectedClientId, X509Certificate2 clientCertificate) + { + var splitAssertion = clientAssertion.Split('.'); + + Assert.IsTrue(splitAssertion.Length == 3); + + var compactHeader = splitAssertion[0]; + var compactPayload = splitAssertion[1]; + var encodedSignature = splitAssertion[2]; + + // verify the JWT header + using (JsonDocument json = JsonDocument.Parse(Base64Url.Decode(compactHeader))) + { + Assert.IsTrue(json.RootElement.TryGetProperty("typ", out JsonElement typProp) && typProp.GetString() == "JWT"); + Assert.IsTrue(json.RootElement.TryGetProperty("alg", out JsonElement algProp) && algProp.GetString() == "RS256"); + Assert.IsTrue(json.RootElement.TryGetProperty("x5t", out JsonElement x5tProp) && x5tProp.GetString() == Base64Url.HexToBase64Url(clientCertificate.Thumbprint)); + } + + // verify the JWT payload + using (JsonDocument json = JsonDocument.Parse(Base64Url.Decode(compactPayload))) + { + Assert.IsTrue(json.RootElement.TryGetProperty("aud", out JsonElement audProp) && audProp.GetString() == $"https://login.microsoftonline.com/{expectedTenantId}/oauth2/v2.0/token"); + Assert.IsTrue(json.RootElement.TryGetProperty("iss", out JsonElement issProp) && issProp.GetString() == expectedClientId); + Assert.IsTrue(json.RootElement.TryGetProperty("sub", out JsonElement subProp) && subProp.GetString() == expectedClientId); + Assert.IsTrue(json.RootElement.TryGetProperty("nbf", out JsonElement nbfProp) && nbfProp.GetInt64() <= DateTimeOffset.UtcNow.ToUnixTimeSeconds()); + Assert.IsTrue(json.RootElement.TryGetProperty("exp", out JsonElement expProp) && expProp.GetInt64() > DateTimeOffset.UtcNow.ToUnixTimeSeconds()); ; + } + + // verify the JWT signature + Assert.IsTrue(clientCertificate.GetRSAPublicKey().VerifyData(Encoding.ASCII.GetBytes(compactHeader + "." + compactPayload), Base64Url.Decode(encodedSignature), HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1)); + } + + public bool TryParseFormEncodedBody(byte[] content, out Dictionary parsed) + { + parsed = new Dictionary(); + + var contentStr = Encoding.UTF8.GetString(content); + + foreach (string parameter in contentStr.Split('&')) + { + if (string.IsNullOrEmpty(parameter)) + { + return false; + } + + var splitParam = parameter.Split('='); + + if(splitParam.Length != 2) + { + return false; + } + + parsed[splitParam[0]] = Uri.UnescapeDataString(splitParam[1]); + } + + return true; + } + } +} \ No newline at end of file diff --git a/sdk/identity/Azure.Identity/tests/Mock/MockClientCertificateCredentialTests.cs b/sdk/identity/Azure.Identity/tests/Mock/MockClientCertificateCredentialTests.cs index a0baf136d87e4..9d9063eb32524 100644 --- a/sdk/identity/Azure.Identity/tests/Mock/MockClientCertificateCredentialTests.cs +++ b/sdk/identity/Azure.Identity/tests/Mock/MockClientCertificateCredentialTests.cs @@ -1,14 +1,6 @@ -using Azure.Core; -using Azure.Core.Testing; -using NUnit.Framework; +using NUnit.Framework; using System; -using System.Collections.Generic; -using System.IO; -using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; -using System.Text; -using System.Text.Json; -using System.Threading.Tasks; namespace Azure.Identity.Tests.Mock { @@ -27,114 +19,5 @@ public void VerifyCtorErrorHandling() Assert.Throws(() => new ClientCertificateCredential(tenantId, null, clientCertificate)); Assert.Throws(() => new ClientCertificateCredential(tenantId, clientId, null)); } - - [Test] - public async Task VerifyClientCertificateRequestAsync() - { - var response = new MockResponse(200); - - var expectedToken = "mock-msi-access-token"; - - response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_in\": 3600 }}"); - - var mockTransport = new MockTransport(response); - - var options = new IdentityClientOptions() { Transport = mockTransport }; - - var expectedTenantId = Guid.NewGuid().ToString(); - - var expectedClientId = Guid.NewGuid().ToString(); - - var mockCert = new X509Certificate2("./Data/cert.pfx", "password"); - - var credential = new ClientCertificateCredential(expectedTenantId, expectedClientId, mockCert, options: options); - - AccessToken actualToken = await credential.GetTokenAsync(MockScopes.Default); - - Assert.AreEqual(expectedToken, actualToken.Token); - - MockRequest request = mockTransport.SingleRequest; - - Assert.IsTrue(request.Content.TryComputeLength(out long contentLen)); - - var content = new byte[contentLen]; - - await request.Content.WriteToAsync(new MemoryStream(content), default); - - Assert.IsTrue(TryParseFormEncodedBody(content, out Dictionary parsedBody)); - - Assert.IsTrue(parsedBody.TryGetValue("response_type", out string responseType) && responseType == "token"); - - Assert.IsTrue(parsedBody.TryGetValue("grant_type", out string grantType) && grantType == "client_credentials"); - - Assert.IsTrue(parsedBody.TryGetValue("client_assertion_type", out string assertionType) && assertionType == "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"); - - Assert.IsTrue(parsedBody.TryGetValue("client_id", out string actualClientId) && actualClientId == expectedClientId); - - Assert.IsTrue(parsedBody.TryGetValue("scope", out string actualScope) && actualScope == MockScopes.Default.ToString()); - - Assert.IsTrue(parsedBody.TryGetValue("client_assertion", out string clientAssertion)); - - // var header - VerifyClientAssertion(clientAssertion, expectedTenantId, expectedClientId, mockCert); - } - - public void VerifyClientAssertion(string clientAssertion, string expectedTenantId, string expectedClientId, X509Certificate2 clientCertificate) - { - var splitAssertion = clientAssertion.Split('.'); - - Assert.IsTrue(splitAssertion.Length == 3); - - var compactHeader = splitAssertion[0]; - var compactPayload = splitAssertion[1]; - var encodedSignature = splitAssertion[2]; - - // verify the JWT header - using (JsonDocument json = JsonDocument.Parse(Base64Url.Decode(compactHeader))) - { - Assert.IsTrue(json.RootElement.TryGetProperty("typ", out JsonElement typProp) && typProp.GetString() == "JWT"); - Assert.IsTrue(json.RootElement.TryGetProperty("alg", out JsonElement algProp) && algProp.GetString() == "RS256"); - Assert.IsTrue(json.RootElement.TryGetProperty("x5t", out JsonElement x5tProp) && x5tProp.GetString() == Base64Url.HexToBase64Url(clientCertificate.Thumbprint)); - } - - // verify the JWT payload - using (JsonDocument json = JsonDocument.Parse(Base64Url.Decode(compactPayload))) - { - Assert.IsTrue(json.RootElement.TryGetProperty("aud", out JsonElement audProp) && audProp.GetString() == $"https://login.microsoftonline.com/{expectedTenantId}/oauth2/v2.0/token"); - Assert.IsTrue(json.RootElement.TryGetProperty("iss", out JsonElement issProp) && issProp.GetString() == expectedClientId); - Assert.IsTrue(json.RootElement.TryGetProperty("sub", out JsonElement subProp) && subProp.GetString() == expectedClientId); - Assert.IsTrue(json.RootElement.TryGetProperty("nbf", out JsonElement nbfProp) && nbfProp.GetInt64() <= DateTimeOffset.UtcNow.ToUnixTimeSeconds()); - Assert.IsTrue(json.RootElement.TryGetProperty("exp", out JsonElement expProp) && expProp.GetInt64() > DateTimeOffset.UtcNow.ToUnixTimeSeconds()); ; - } - - // verify the JWT signature - Assert.IsTrue(clientCertificate.GetRSAPublicKey().VerifyData(Encoding.ASCII.GetBytes(compactHeader + "." + compactPayload), Base64Url.Decode(encodedSignature), HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1)); - } - - public bool TryParseFormEncodedBody(byte[] content, out Dictionary parsed) - { - parsed = new Dictionary(); - - var contentStr = Encoding.UTF8.GetString(content); - - foreach (string parameter in contentStr.Split('&')) - { - if (string.IsNullOrEmpty(parameter)) - { - return false; - } - - var splitParam = parameter.Split('='); - - if(splitParam.Length != 2) - { - return false; - } - - parsed[splitParam[0]] = Uri.UnescapeDataString(splitParam[1]); - } - - return true; - } } }