Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Service Fabric MSI #729

Merged
merged 6 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ class AppServiceManagedIdentitySource extends AbstractManagedIdentitySource{
private static final String APP_SERVICE_MSI_API_VERSION = "2019-08-01";
private static final String SECRET_HEADER_NAME = "X-IDENTITY-HEADER";

private final URI MSI_ENDPOINT;
private final String SECRET;
private final URI msiEndpoint;
private final String identityHeader;

@Override
public void createManagedIdentityRequest(String resource) {
managedIdentityRequest.baseEndpoint = MSI_ENDPOINT;
managedIdentityRequest.baseEndpoint = msiEndpoint;
managedIdentityRequest.method = HttpMethod.GET;

managedIdentityRequest.headers = new HashMap<>();
managedIdentityRequest.headers.put(SECRET_HEADER_NAME, SECRET);
managedIdentityRequest.headers.put(SECRET_HEADER_NAME, identityHeader);

managedIdentityRequest.queryParameters = new HashMap<>();
managedIdentityRequest.queryParameters.put("api-version", Collections.singletonList(APP_SERVICE_MSI_API_VERSION));
Expand All @@ -50,8 +50,8 @@ public void createManagedIdentityRequest(String resource) {
private AppServiceManagedIdentitySource(MsalRequest msalRequest, ServiceBundle serviceBundle, URI msiEndpoint, String secret)
{
super(msalRequest, serviceBundle, ManagedIdentitySourceType.APP_SERVICE);
this.MSI_ENDPOINT = msiEndpoint;
this.SECRET = secret;
this.msiEndpoint = msiEndpoint;
this.identityHeader = secret;
}

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ class CloudShellManagedIdentitySource extends AbstractManagedIdentitySource{

private static final Logger LOG = LoggerFactory.getLogger(CloudShellManagedIdentitySource.class);

private final URI MSI_ENDPOINT;
private final URI msiEndpoint;

@Override
public void createManagedIdentityRequest(String resource) {
managedIdentityRequest.baseEndpoint = MSI_ENDPOINT;
managedIdentityRequest.baseEndpoint = msiEndpoint;
managedIdentityRequest.method = HttpMethod.POST;

managedIdentityRequest.headers = new HashMap<>();
Expand All @@ -33,7 +33,7 @@ public void createManagedIdentityRequest(String resource) {
private CloudShellManagedIdentitySource(MsalRequest msalRequest, ServiceBundle serviceBundle, URI msiEndpoint)
{
super(msalRequest, serviceBundle, ManagedIdentitySourceType.CLOUD_SHELL);
this.MSI_ENDPOINT = msiEndpoint;
this.msiEndpoint = msiEndpoint;

ManagedIdentityIdType idType =
((ManagedIdentityApplication) msalRequest.application()).getManagedIdentityId().getIdType();
Expand All @@ -57,28 +57,23 @@ static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBund
return null;
}

URI validatedUri = validateAndGetUri(msiEndpoint);
return validatedUri == null ? null
: new CloudShellManagedIdentitySource(msalRequest, serviceBundle, validatedUri);
return new CloudShellManagedIdentitySource(msalRequest, serviceBundle, validateAndGetUri(msiEndpoint));
}

private static URI validateAndGetUri(String msiEndpoint)
{
URI endpointUri = null;

try
{
endpointUri = new URI(msiEndpoint);
URI endpointUri = new URI(msiEndpoint);
LOG.info("[Managed Identity] Environment variables validation passed for cloud shell managed identity. Endpoint URI: " + endpointUri + ". Creating cloud shell managed identity.");
return endpointUri;
}
catch (URISyntaxException ex)
{
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT, String.format(
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR, "MSI_ENDPOINT", msiEndpoint, "Cloud Shell"),
ManagedIdentitySourceType.CLOUD_SHELL);
}

LOG.info("[Managed Identity] Environment variables validation passed for cloud shell managed identity. Endpoint URI: " + endpointUri + ". Creating cloud shell managed identity.");
return endpointUri;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ ManagedIdentityResponse getManagedIdentityResponse(ManagedIdentityParameters par
private static AbstractManagedIdentitySource createManagedIdentitySource(MsalRequest msalRequest,
ServiceBundle serviceBundle) {
AbstractManagedIdentitySource managedIdentitySource;
if ((managedIdentitySource = AppServiceManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
if ((managedIdentitySource = ServiceFabricManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else if ((managedIdentitySource = AppServiceManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
} else if ((managedIdentitySource = CloudShellManagedIdentitySource.create(msalRequest, serviceBundle)) != null) {
return managedIdentitySource;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.aad.msal4j;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collections;
import java.util.HashMap;

class ServiceFabricManagedIdentitySource extends AbstractManagedIdentitySource {

private static final Logger LOG = LoggerFactory.getLogger(ServiceFabricManagedIdentitySource.class);

private static final String SERVICE_FABRIC_MSI_API_VERSION = "2019-07-01-preview";

private final URI msiEndpoint;
private final String identityHeader;
private final ManagedIdentityIdType idType;
private final String userAssignedId;

@Override
public void createManagedIdentityRequest(String resource) {
managedIdentityRequest.baseEndpoint = msiEndpoint;
managedIdentityRequest.method = HttpMethod.GET;

managedIdentityRequest.headers = new HashMap<>();
managedIdentityRequest.headers.put("secret", identityHeader);
Avery-Dunn marked this conversation as resolved.
Show resolved Hide resolved

managedIdentityRequest.queryParameters = new HashMap<>();
managedIdentityRequest.queryParameters.put("resource", Collections.singletonList(resource));
managedIdentityRequest.queryParameters.put("api-version", Collections.singletonList(SERVICE_FABRIC_MSI_API_VERSION));

if (idType == ManagedIdentityIdType.CLIENT_ID) {
LOG.info("[Managed Identity] Adding user assigned client id to the request for Service Fabric Managed Identity.");
managedIdentityRequest.queryParameters.put(Constants.MANAGED_IDENTITY_CLIENT_ID, Collections.singletonList(userAssignedId));
} else if (idType == ManagedIdentityIdType.RESOURCE_ID) {
LOG.info("[Managed Identity] Adding user assigned resource id to the request for Service Fabric Managed Identity.");
managedIdentityRequest.queryParameters.put(Constants.MANAGED_IDENTITY_RESOURCE_ID, Collections.singletonList(userAssignedId));
}
}

private ServiceFabricManagedIdentitySource(MsalRequest msalRequest, ServiceBundle serviceBundle, URI msiEndpoint, String identityHeader)
{
super(msalRequest, serviceBundle, ManagedIdentitySourceType.SERVICE_FABRIC);
this.msiEndpoint = msiEndpoint;
this.identityHeader = identityHeader;

this.idType = ((ManagedIdentityApplication) msalRequest.application()).getManagedIdentityId().getIdType();
this.userAssignedId = ((ManagedIdentityApplication) msalRequest.application()).getManagedIdentityId().getUserAssignedId();
}

static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBundle serviceBundle) {

IEnvironmentVariables environmentVariables = getEnvironmentVariables((ManagedIdentityParameters) msalRequest.requestContext().apiParameters());
String msiEndpoint = environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT);
String identityHeader = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT);
String identityServerThumbprint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_SERVER_THUMBPRINT);


if (StringHelper.isNullOrBlank(msiEndpoint) || StringHelper.isNullOrBlank(identityHeader) || StringHelper.isNullOrBlank(identityServerThumbprint))
{
LOG.info("[Managed Identity] Service fabric managed identity is unavailable.");
return null;
}

return new ServiceFabricManagedIdentitySource(msalRequest, serviceBundle, validateAndGetUri(msiEndpoint), identityHeader);
}

private static URI validateAndGetUri(String msiEndpoint)
{
try
{
URI endpointUri = new URI(msiEndpoint);
LOG.info("[Managed Identity] Environment variables validation passed for Service Fabric Managed Identity. Endpoint URI: " + endpointUri);
return endpointUri;
}
catch (URISyntaxException ex)
{
throw new MsalManagedIdentityException(MsalError.INVALID_MANAGED_IDENTITY_ENDPOINT, String.format(
MsalErrorMessage.MANAGED_IDENTITY_ENDPOINT_INVALID_URI_ERROR, "MSI_ENDPOINT", msiEndpoint, "Service Fabric"),
ManagedIdentitySourceType.SERVICE_FABRIC);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ public static Stream<Arguments> createData() {
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT,
ManagedIdentityTests.resourceDefaultSuffix),
Arguments.of(ManagedIdentitySourceType.IMDS, null,
ManagedIdentityTests.resource));
ManagedIdentityTests.resource),
Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint,
ManagedIdentityTests.resource),
Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint,
ManagedIdentityTests.resourceDefaultSuffix));
}

public static Stream<Arguments> createDataUserAssigned() {
Expand All @@ -42,6 +46,10 @@ public static Stream<Arguments> createDataUserAssigned() {
Arguments.of(ManagedIdentitySourceType.IMDS, null,
ManagedIdentityId.userAssignedClientId(CLIENT_ID)),
Arguments.of(ManagedIdentitySourceType.IMDS, null,
ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)),
Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint,
ManagedIdentityId.userAssignedResourceId(CLIENT_ID)),
Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint,
ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)));
}

Expand Down Expand Up @@ -74,6 +82,10 @@ public static Stream<Arguments> createDataWrongScope() {
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT,
"user.read"),
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT,
"https://management.core.windows.net//user_impersonation"),
Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint,
"user.read"),
Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint,
"https://management.core.windows.net//user_impersonation"));
}

Expand All @@ -82,6 +94,7 @@ public static Stream<Arguments> createDataError() {
Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint),
Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint),
Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint),
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT));
Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT),
Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ private HttpRequest expectedRequest(ManagedIdentitySourceType source, String res
headers.put("Metadata", "true");
break;
}
case SERVICE_FABRIC:
endpoint = serviceFabricEndpoint;
queryParameters.put("api-version", Collections.singletonList("2019-07-01-preview"));
queryParameters.put("resource", Collections.singletonList(resource));
break;
}

switch (id.getIdType()) {
Expand Down Expand Up @@ -412,7 +417,7 @@ void azureArcManagedIdentity_MissingAuthHeader() throws Exception {

// Clear caching to avoid cross test pollution.
miApp.tokenCache().accessTokens.clear();

try {
miApp.acquireTokenForManagedIdentity(
ManagedIdentityParameters.builder(resource)
Expand Down Expand Up @@ -447,12 +452,12 @@ void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoi

// Clear caching to avoid cross test pollution.
miApp.tokenCache().accessTokens.clear();

ManagedIdentityApplication miApp2 = ManagedIdentityApplication
.builder(ManagedIdentityId.systemAssigned())
.httpClient(httpClientMock)
.build();

IAuthenticationResult resultMiApp1 = miApp.acquireTokenForManagedIdentity(
ManagedIdentityParameters.builder(resource)
.environmentVariables(environmentVariables)
Expand Down