diff --git a/sdk/identity/Azure.Identity/src/Base64Url.cs b/sdk/identity/Azure.Identity/src/Base64Url.cs new file mode 100644 index 0000000000000..c5f39055ed746 --- /dev/null +++ b/sdk/identity/Azure.Identity/src/Base64Url.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for +// license information. + +using System; +using System.Text; + +namespace Azure.Identity +{ + internal static class Base64Url + { + public static byte[] Decode(string str) + { + str = new StringBuilder(str).Replace('-', '+').Replace('_', '/').Append('=', (str.Length % 4 == 0) ? 0 : 4 - (str.Length % 4)).ToString(); + + return Convert.FromBase64String(str); + } + + public static string Encode(byte[] bytes) + { + return new StringBuilder(Convert.ToBase64String(bytes)).Replace('+', '-').Replace('/', '_').Replace("=", "").ToString(); + } + + + public static string HexToBase64Url(string hex) + { + byte[] bytes = new byte[hex.Length / 2]; + + for (int i = 0; i < hex.Length; i += 2) + bytes[i / 2] = Convert.ToByte(hex.Substring(i, 2), 16); + + return Base64Url.Encode(bytes); + } + } +} \ No newline at end of file diff --git a/sdk/identity/Azure.Identity/src/ClientCertificateCredential.cs b/sdk/identity/Azure.Identity/src/ClientCertificateCredential.cs new file mode 100644 index 0000000000000..750944599f0f8 --- /dev/null +++ b/sdk/identity/Azure.Identity/src/ClientCertificateCredential.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for +// license information. + +using Azure.Core; +using System; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; + +namespace Azure.Identity +{ + public class ClientCertificateCredential : TokenCredential + { + private string _tenantId; + private string _clientId; + private X509Certificate2 _clientCertificate; + private IdentityClient _client; + + public ClientCertificateCredential(string tenantId, string clientId, X509Certificate2 clientCertificate) + : this(tenantId, clientId, clientCertificate, null) + { + } + + public ClientCertificateCredential(string tenantId, string clientId, X509Certificate2 clientCertificate, IdentityClientOptions options) + { + _tenantId = tenantId ?? throw new ArgumentNullException(nameof(tenantId)); + + _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); + + _clientCertificate = clientCertificate ?? throw new ArgumentNullException(nameof(clientCertificate)); + + _client = (options != null) ? new IdentityClient(options) : IdentityClient.SharedClient; + } + + public override AccessToken GetToken(string[] scopes, CancellationToken cancellationToken = default) + { + return _client.Authenticate(_tenantId, _clientId, _clientCertificate, scopes, cancellationToken); + } + + public override async Task GetTokenAsync(string[] scopes, CancellationToken cancellationToken = default) + { + return await _client.AuthenticateAsync(_tenantId, _clientId, _clientCertificate, scopes, cancellationToken); + } + } +} diff --git a/sdk/identity/Azure.Identity/src/IdentityClient.cs b/sdk/identity/Azure.Identity/src/IdentityClient.cs index f4d3f01a81a63..c61aef85f238f 100644 --- a/sdk/identity/Azure.Identity/src/IdentityClient.cs +++ b/sdk/identity/Azure.Identity/src/IdentityClient.cs @@ -9,6 +9,8 @@ using System.Collections.Generic; using System.IO; using System.Runtime.CompilerServices; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; using System.Text; using System.Text.Json; using System.Threading; @@ -23,7 +25,8 @@ internal class IdentityClient private readonly IdentityClientOptions _options; private readonly HttpPipeline _pipeline; private readonly Uri ImdsEndptoint = new Uri("http://169.254.169.254/metadata/identity/oauth2/token"); - private readonly string MsiApiVersion = "2018-02-01"; + private const string MsiApiVersion = "2018-02-01"; + private const string ClientAssertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; public IdentityClient(IdentityClientOptions options = null) { @@ -69,6 +72,39 @@ public virtual AccessToken Authenticate(string tenantId, string clientId, string } } + public virtual async Task AuthenticateAsync(string tenantId, string clientId, X509Certificate2 clientCertificate, string[] scopes, CancellationToken cancellationToken = default) + { + using (Request request = CreateClientCertificateAuthRequest(tenantId, clientId, clientCertificate, scopes)) + { + var response = await _pipeline.SendRequestAsync(request, cancellationToken).ConfigureAwait(false); + + if (response.Status == 200 || response.Status == 201) + { + var result = await DeserializeAsync(response.ContentStream, cancellationToken).ConfigureAwait(false); + + return new Response(response, result); + } + + throw await response.CreateRequestFailedExceptionAsync(); + } + } + + public virtual AccessToken Authenticate(string tenantId, string clientId, X509Certificate2 clientCertificate, string[] scopes, CancellationToken cancellationToken = default) + { + using (Request request = CreateClientCertificateAuthRequest(tenantId, clientId, clientCertificate, scopes)) + { + var response = _pipeline.SendRequest(request, cancellationToken); + + if (response.Status == 200 || response.Status == 201) + { + var result = Deserialize(response.ContentStream); + + return new Response(response, result); + } + + throw response.CreateRequestFailedException(); + } + } public virtual async Task AuthenticateManagedIdentityAsync(string[] scopes, string clientId = null, CancellationToken cancellationToken = default) { using (Request request = CreateManagedIdentityAuthRequest(scopes, clientId)) @@ -152,6 +188,77 @@ private Request CreateClientSecretAuthRequest(string tenantId, string clientId, return request; } + private Request CreateClientCertificateAuthRequest(string tenantId, string clientId, X509Certificate2 clientCertficate, string[] scopes) + { + Request request = _pipeline.CreateRequest(); + + request.Method = HttpPipelineMethod.Post; + + request.Headers.SetValue("Content-Type", "application/x-www-form-urlencoded"); + + request.UriBuilder.Uri = _options.AuthorityHost; + + request.UriBuilder.AppendPath(tenantId); + + request.UriBuilder.AppendPath("/oauth2/v2.0/token"); + + string clientAssertion = CreateClientAssertionJWT(clientId, request.UriBuilder.ToString(), clientCertficate); + + var bodyStr = $"response_type=token&grant_type=client_credentials&client_id={Uri.EscapeDataString(clientId)}&client_assertion_type={Uri.EscapeDataString(ClientAssertionType)}&client_assertion={Uri.EscapeDataString(clientAssertion)}&scope={Uri.EscapeDataString(string.Join(" ", scopes))}"; + + ReadOnlyMemory content = Encoding.UTF8.GetBytes(bodyStr).AsMemory(); + + request.Content = HttpPipelineRequestContent.Create(content); + + return request; + } + + private string CreateClientAssertionJWT(string clientId, string audience, X509Certificate2 clientCertificate) + { + var headerBuff = new ArrayBufferWriter(); + + using (var headerJson = new Utf8JsonWriter(headerBuff)) + { + headerJson.WriteStartObject(); + + headerJson.WriteString("typ", "JWT"); + headerJson.WriteString("alg", "RS256"); + headerJson.WriteString("x5t", Base64Url.HexToBase64Url(clientCertificate.Thumbprint)); + + headerJson.WriteEndObject(); + + headerJson.Flush(); + } + + var payloadBuff = new ArrayBufferWriter(); + + using (var payloadJson = new Utf8JsonWriter(payloadBuff)) + { + payloadJson.WriteStartObject(); + + payloadJson.WriteString("jti", Guid.NewGuid()); + payloadJson.WriteString("aud", audience); + payloadJson.WriteString("iss", clientId); + payloadJson.WriteString("sub", clientId); + payloadJson.WriteNumber("nbf", DateTimeOffset.UtcNow.ToUnixTimeSeconds()); + payloadJson.WriteNumber("exp", (DateTimeOffset.UtcNow + TimeSpan.FromMinutes(30)).ToUnixTimeSeconds()); + + payloadJson.WriteEndObject(); + + payloadJson.Flush(); + } + + string header = Base64Url.Encode(headerBuff.WrittenMemory.ToArray()); + + string payload = Base64Url.Encode(payloadBuff.WrittenMemory.ToArray()); + + string flattenedJws = header + "." + payload; + + byte[] signature = clientCertificate.GetRSAPrivateKey().SignData(Encoding.ASCII.GetBytes(flattenedJws), HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + + return flattenedJws + "." + Base64Url.Encode(signature); + } + private async Task DeserializeAsync(Stream content, CancellationToken cancellationToken) { using (JsonDocument json = await JsonDocument.ParseAsync(content, default, cancellationToken).ConfigureAwait(false)) diff --git a/sdk/identity/Azure.Identity/tests/Azure.Identity.Tests.csproj b/sdk/identity/Azure.Identity/tests/Azure.Identity.Tests.csproj index 27463e596d5ab..d410e28ca09de 100644 --- a/sdk/identity/Azure.Identity/tests/Azure.Identity.Tests.csproj +++ b/sdk/identity/Azure.Identity/tests/Azure.Identity.Tests.csproj @@ -18,6 +18,7 @@ + diff --git a/sdk/identity/Azure.Identity/tests/Data/cert.pfx b/sdk/identity/Azure.Identity/tests/Data/cert.pfx new file mode 100644 index 0000000000000..52ed251dc2b5d Binary files /dev/null and b/sdk/identity/Azure.Identity/tests/Data/cert.pfx differ diff --git a/sdk/identity/Azure.Identity/tests/Mock/MockClientCertificateCredentialTests.cs b/sdk/identity/Azure.Identity/tests/Mock/MockClientCertificateCredentialTests.cs new file mode 100644 index 0000000000000..a0baf136d87e4 --- /dev/null +++ b/sdk/identity/Azure.Identity/tests/Mock/MockClientCertificateCredentialTests.cs @@ -0,0 +1,140 @@ +using Azure.Core; +using Azure.Core.Testing; +using NUnit.Framework; +using System; +using System.Collections.Generic; +using System.IO; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; + +namespace Azure.Identity.Tests.Mock +{ + public class MockClientCertificateCredentialTests + { + [Test] + public void VerifyCtorErrorHandling() + { + var clientCertificate = new X509Certificate2(@"./Data/cert.pfx", "password"); + + var tenantId = Guid.NewGuid().ToString(); + + var clientId = Guid.NewGuid().ToString(); + + Assert.Throws(() => new ClientCertificateCredential(null, clientId, clientCertificate)); + Assert.Throws(() => new ClientCertificateCredential(tenantId, null, clientCertificate)); + Assert.Throws(() => new ClientCertificateCredential(tenantId, clientId, null)); + } + + [Test] + public async Task VerifyClientCertificateRequestAsync() + { + var response = new MockResponse(200); + + var expectedToken = "mock-msi-access-token"; + + response.SetContent($"{{ \"access_token\": \"{expectedToken}\", \"expires_in\": 3600 }}"); + + var mockTransport = new MockTransport(response); + + var options = new IdentityClientOptions() { Transport = mockTransport }; + + var expectedTenantId = Guid.NewGuid().ToString(); + + var expectedClientId = Guid.NewGuid().ToString(); + + var mockCert = new X509Certificate2("./Data/cert.pfx", "password"); + + var credential = new ClientCertificateCredential(expectedTenantId, expectedClientId, mockCert, options: options); + + AccessToken actualToken = await credential.GetTokenAsync(MockScopes.Default); + + Assert.AreEqual(expectedToken, actualToken.Token); + + MockRequest request = mockTransport.SingleRequest; + + Assert.IsTrue(request.Content.TryComputeLength(out long contentLen)); + + var content = new byte[contentLen]; + + await request.Content.WriteToAsync(new MemoryStream(content), default); + + Assert.IsTrue(TryParseFormEncodedBody(content, out Dictionary parsedBody)); + + Assert.IsTrue(parsedBody.TryGetValue("response_type", out string responseType) && responseType == "token"); + + Assert.IsTrue(parsedBody.TryGetValue("grant_type", out string grantType) && grantType == "client_credentials"); + + Assert.IsTrue(parsedBody.TryGetValue("client_assertion_type", out string assertionType) && assertionType == "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"); + + Assert.IsTrue(parsedBody.TryGetValue("client_id", out string actualClientId) && actualClientId == expectedClientId); + + Assert.IsTrue(parsedBody.TryGetValue("scope", out string actualScope) && actualScope == MockScopes.Default.ToString()); + + Assert.IsTrue(parsedBody.TryGetValue("client_assertion", out string clientAssertion)); + + // var header + VerifyClientAssertion(clientAssertion, expectedTenantId, expectedClientId, mockCert); + } + + public void VerifyClientAssertion(string clientAssertion, string expectedTenantId, string expectedClientId, X509Certificate2 clientCertificate) + { + var splitAssertion = clientAssertion.Split('.'); + + Assert.IsTrue(splitAssertion.Length == 3); + + var compactHeader = splitAssertion[0]; + var compactPayload = splitAssertion[1]; + var encodedSignature = splitAssertion[2]; + + // verify the JWT header + using (JsonDocument json = JsonDocument.Parse(Base64Url.Decode(compactHeader))) + { + Assert.IsTrue(json.RootElement.TryGetProperty("typ", out JsonElement typProp) && typProp.GetString() == "JWT"); + Assert.IsTrue(json.RootElement.TryGetProperty("alg", out JsonElement algProp) && algProp.GetString() == "RS256"); + Assert.IsTrue(json.RootElement.TryGetProperty("x5t", out JsonElement x5tProp) && x5tProp.GetString() == Base64Url.HexToBase64Url(clientCertificate.Thumbprint)); + } + + // verify the JWT payload + using (JsonDocument json = JsonDocument.Parse(Base64Url.Decode(compactPayload))) + { + Assert.IsTrue(json.RootElement.TryGetProperty("aud", out JsonElement audProp) && audProp.GetString() == $"https://login.microsoftonline.com/{expectedTenantId}/oauth2/v2.0/token"); + Assert.IsTrue(json.RootElement.TryGetProperty("iss", out JsonElement issProp) && issProp.GetString() == expectedClientId); + Assert.IsTrue(json.RootElement.TryGetProperty("sub", out JsonElement subProp) && subProp.GetString() == expectedClientId); + Assert.IsTrue(json.RootElement.TryGetProperty("nbf", out JsonElement nbfProp) && nbfProp.GetInt64() <= DateTimeOffset.UtcNow.ToUnixTimeSeconds()); + Assert.IsTrue(json.RootElement.TryGetProperty("exp", out JsonElement expProp) && expProp.GetInt64() > DateTimeOffset.UtcNow.ToUnixTimeSeconds()); ; + } + + // verify the JWT signature + Assert.IsTrue(clientCertificate.GetRSAPublicKey().VerifyData(Encoding.ASCII.GetBytes(compactHeader + "." + compactPayload), Base64Url.Decode(encodedSignature), HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1)); + } + + public bool TryParseFormEncodedBody(byte[] content, out Dictionary parsed) + { + parsed = new Dictionary(); + + var contentStr = Encoding.UTF8.GetString(content); + + foreach (string parameter in contentStr.Split('&')) + { + if (string.IsNullOrEmpty(parameter)) + { + return false; + } + + var splitParam = parameter.Split('='); + + if(splitParam.Length != 2) + { + return false; + } + + parsed[splitParam[0]] = Uri.UnescapeDataString(splitParam[1]); + } + + return true; + } + } +}