Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Client Assertion Support in OBO + Fix PWSH auth for Ubuntu #40552

Merged
merged 7 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import com.azure.identity.implementation.util.LoggingUtil;
import reactor.core.publisher.Mono;

import java.util.function.Supplier;

/**
* <p>On Behalf of authentication in Azure is a way for a user or application to authenticate to a service or resource
* using credentials from another identity provider. This type of authentication is typically used when a user or
Expand Down Expand Up @@ -64,12 +66,14 @@ public class OnBehalfOfCredential implements TokenCredential {
* @param identityClientOptions the options for configuring the identity client
*/
OnBehalfOfCredential(String clientId, String tenantId, String clientSecret, String certificatePath,
String certificatePassword, IdentityClientOptions identityClientOptions) {
String certificatePassword, Supplier<String> clientAssertionSupplier,
IdentityClientOptions identityClientOptions) {
IdentityClientBuilder builder = new IdentityClientBuilder()
.tenantId(tenantId)
.clientId(clientId)
.clientSecret(clientSecret)
.certificatePath(certificatePath)
.clientAssertionSupplier(clientAssertionSupplier)
.certificatePassword(certificatePassword)
.identityClientOptions(identityClientOptions);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import com.azure.core.util.logging.ClientLogger;
import com.azure.identity.implementation.util.ValidationUtil;

import java.util.function.Supplier;

/**
* Fluent credential builder for instantiating a {@link OnBehalfOfCredential}.
*
Expand Down Expand Up @@ -47,6 +49,8 @@ public class OnBehalfOfCredentialBuilder extends AadCredentialBuilderBase<OnBeha
private String clientSecret;
private String clientCertificatePath;
private String clientCertificatePassword;
private Supplier<String> clientAssertionSupplier;


/**
* Constructs an instance of OnBehalfOfCredentialBuilder.
Expand Down Expand Up @@ -136,6 +140,17 @@ public OnBehalfOfCredentialBuilder userAssertion(String userAssertion) {
return this;
}

/**
* Sets the supplier containing the logic to supply the client assertion when invoked.
*
* @param clientAssertionSupplier the supplier supplying client assertion.
* @return An updated instance of this builder.
*/
public OnBehalfOfCredentialBuilder clientAssertion(Supplier<String> clientAssertionSupplier) {
this.clientAssertionSupplier = clientAssertionSupplier;
return this;
}

/**
* Creates a new {@link OnBehalfOfCredential} with the current configurations.
*
Expand All @@ -146,19 +161,16 @@ public OnBehalfOfCredentialBuilder userAssertion(String userAssertion) {
public OnBehalfOfCredential build() {
ValidationUtil.validate(CLASS_NAME, LOGGER, "clientId", clientId, "tenantId", tenantId);

if (clientSecret == null && clientCertificatePath == null) {
throw LOGGER.logExceptionAsWarning(new IllegalArgumentException("At least client secret or certificate "
+ "path should provided in OnBehalfOfCredentialBuilder. Only one of them should "
+ "be provided."));
}

if (clientCertificatePath != null && clientSecret != null) {
throw LOGGER.logExceptionAsWarning(new IllegalArgumentException("Both client secret and certificate "
+ "path are provided in OnBehalfCredentialBuilder. Only one of them should "
+ "be provided."));
if ((clientSecret == null && clientCertificatePath == null && clientAssertionSupplier == null)
|| (clientSecret != null && clientCertificatePath != null)
|| (clientSecret != null && clientAssertionSupplier != null)
|| (clientCertificatePath != null && clientAssertionSupplier != null)) {
throw LOGGER.logExceptionAsWarning(new IllegalArgumentException("Exactly one of client secret, "
+ "client certificate path, or client assertion supplier must be provided "
+ "in OnBehalfOfCredentialBuilder."));
}

return new OnBehalfOfCredential(clientId, tenantId, clientSecret, clientCertificatePath,
clientCertificatePassword, identityClientOptions);
clientCertificatePassword, clientAssertionSupplier, identityClientOptions);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -425,11 +425,10 @@ public Mono<AccessToken> authenticateWithAzurePowerShell(TokenRequestContext req
ValidationUtil.validateTenantIdCharacterRange(tenantId, LOGGER);
List<CredentialUnavailableException> exceptions = new ArrayList<>(2);

PowershellManager defaultPowerShellManager = new PowershellManager(Platform.isWindows()
? DEFAULT_WINDOWS_PS_EXECUTABLE : DEFAULT_LINUX_PS_EXECUTABLE);
PowershellManager defaultPowerShellManager = new PowershellManager(false);

PowershellManager legacyPowerShellManager = Platform.isWindows()
? new PowershellManager(LEGACY_WINDOWS_PS_EXECUTABLE) : null;
? new PowershellManager(true) : null;

List<PowershellManager> powershellManagers = new ArrayList<>(2);
powershellManagers.add(defaultPowerShellManager);
Expand Down Expand Up @@ -486,9 +485,9 @@ private Mono<AccessToken> getAccessTokenFromPowerShell(TokenRequestContext reque
} catch (IllegalArgumentException ex) {
throw LOGGER.logExceptionAsError(ex);
}
return Mono.using(() -> powershellManager, manager -> manager.initSession().flatMap(m -> {
return Mono.defer(() -> {
String azAccountsCommand = "Import-Module Az.Accounts -MinimumVersion 2.2.0 -PassThru";
return m.runCommand(azAccountsCommand).flatMap(output -> {
return powershellManager.runCommand(azAccountsCommand).flatMap(output -> {
if (output.contains("The specified module 'Az.Accounts' with version '2.2.0' was not loaded "
+ "because no valid module file")) {
return Mono.error(LoggingUtil.logCredentialUnavailableException(LOGGER, options,
Expand All @@ -504,7 +503,7 @@ private Mono<AccessToken> getAccessTokenFromPowerShell(TokenRequestContext reque
LOGGER.verbose("Azure Powershell Authentication => Executing the command `{}` in Azure "
+ "Powershell to retrieve the Access Token.", command);

return m.runCommand(command).flatMap(out -> {
return powershellManager.runCommand(command).flatMap(out -> {
if (out.contains("Run Connect-AzAccount to login")) {
return Mono.error(LoggingUtil.logCredentialUnavailableException(LOGGER, options,
new CredentialUnavailableException(
Expand All @@ -527,7 +526,7 @@ private Mono<AccessToken> getAccessTokenFromPowerShell(TokenRequestContext reque
}
});
});
}), PowershellManager::close);
});
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,173 +6,52 @@
import com.azure.core.util.logging.ClientLogger;
import com.azure.identity.CredentialUnavailableException;
import com.sun.jna.Platform;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.regex.Pattern;

public class PowershellManager {

private static final ClientLogger LOGGER = new ClientLogger(PowershellManager.class);
public static final Pattern PS_RESPONSE_PATTERN = Pattern.compile("\\s+$");
private Process process;
private PrintWriter commandWriter;
private boolean closed;
private int waitPause = 1000;
private long maxWait = 10000L;
private static final String DEFAULT_WINDOWS_POWERSHELL_PATH = "pwsh.exe";
private static final String LEGACY_WINDOWS_POWERSHELL_PATH = "powershell.exe";
private static final String DEFAULT_NIX_POWERSHELL_PATH = "pwsh";
private final String powershellPath;
private ExecutorService executorService;


public PowershellManager(String powershellPath) {
this.powershellPath = powershellPath;
}

public PowershellManager(String powershellPath, ExecutorService executorService) {
this.powershellPath = powershellPath;
this.executorService = executorService;
}

public Mono<PowershellManager> initSession() {

ProcessBuilder pb;
public PowershellManager(boolean useLegacyPowerShell) {
if (Platform.isWindows()) {
pb = new ProcessBuilder("cmd.exe", "/c", "chcp", "65001", ">", "NUL", "&",
powershellPath, "-ExecutionPolicy", "Bypass", "-NoExit", "-NoProfile", "-Command", "-");
this.powershellPath = useLegacyPowerShell ? LEGACY_WINDOWS_POWERSHELL_PATH : DEFAULT_WINDOWS_POWERSHELL_PATH;
} else {
pb = new ProcessBuilder(powershellPath, "-nologo", "-noexit", "-Command", "-");
this.powershellPath = DEFAULT_NIX_POWERSHELL_PATH;
}

pb.redirectErrorStream(true);


Supplier<PowershellManager> supplier = () -> {
try {
this.process = pb.start();
this.commandWriter = new PrintWriter(
new OutputStreamWriter(new BufferedOutputStream(process.getOutputStream()), StandardCharsets.UTF_8),
true);
if (this.process.waitFor(4L, TimeUnit.SECONDS) && !this.process.isAlive()) {
throw new CredentialUnavailableException("Unable to execute PowerShell."
+ " Please make sure that it is installed in your system.");
}
this.closed = false;
} catch (InterruptedException | IOException e) {
throw new CredentialUnavailableException("Unable to execute PowerShell. "
+ "Please make sure that it is installed in your system", e);
}
return this;
};
return executorService != null ? Mono.fromFuture(CompletableFuture.supplyAsync(supplier, executorService))
: Mono.fromFuture(CompletableFuture.supplyAsync(supplier));
}


public Mono<String> runCommand(String command) {
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream(),
StandardCharsets.UTF_8));
StringBuilder powerShellOutput = new StringBuilder();
commandWriter.println(command);
return canRead(reader)
.flatMap(b -> {
if (b) {
return readData(reader, powerShellOutput)
.flatMap(ignored -> Mono.just(PS_RESPONSE_PATTERN.matcher(powerShellOutput.toString())
.replaceAll("")));
} else {
return Mono.error(new CredentialUnavailableException("Error reading data from reader"));
}
});
}

private Mono<Boolean> readData(BufferedReader reader, StringBuilder powerShellOutput) {
return Mono.defer(() -> {
String line;
public Mono<String> runCommand(String input) {
return Mono.fromCallable(() -> {
try {
line = reader.readLine();
if (line != null) {
powerShellOutput.append(line).append("\r\n");
return canRead(reader).flatMap(b -> {
if (!this.closed && b) {
return Mono.empty();
}
return Mono.just(true);
});
} else {
return Mono.just(true);
}
} catch (IOException e) {
return Mono.error(
new CredentialUnavailableException("Powershell reader not ready for reading", e));
}
}).repeatWhenEmpty((Flux<Long> longFlux) -> longFlux.concatMap(ignored -> Flux.just(true)));
}

private Mono<Boolean> canRead(BufferedReader reader) {
Supplier<Boolean> supplier = () -> {
int pause = 62;
int maxPause = Platform.isMac() ? this.waitPause : 500;
while (true) {
try {
if (!reader.ready()) {
if (pause > maxPause) {
return false;
}
pause *= 2;
Thread.sleep((long) pause);
} else {
break;
String[] command = Platform.isWindows()
? new String[]{powershellPath, "-Command", input}
: new String[]{"/bin/bash", "-c", String.format("%s -Command '%s'", powershellPath, input)};


ProcessBuilder processBuilder = new ProcessBuilder(command);
processBuilder.redirectErrorStream(true);
Process process = processBuilder.start();
process.waitFor(10000L, TimeUnit.MILLISECONDS);
// Read output
StringBuilder output = new StringBuilder();
try (BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()))) {
String line;
while ((line = reader.readLine()) != null) {
output.append(line).append(System.lineSeparator());
}

} catch (IOException | InterruptedException e) {
throw new CredentialUnavailableException("Powershell reader not ready for reading", e);
}
return output.toString();
} catch (IOException | InterruptedException e) {
throw LOGGER.logExceptionAsError(new CredentialUnavailableException("PowerShell command failure.", e));
}
return true;
};
return executorService != null ? Mono.fromFuture(CompletableFuture.supplyAsync(supplier, executorService))
: Mono.fromFuture(CompletableFuture.supplyAsync(supplier));
}

public Mono<Boolean> close() {
if (!this.closed && this.process != null) {
Supplier<Boolean> supplier = () -> {
this.commandWriter.println("exit");
try {
this.process.waitFor(maxWait, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
LOGGER.logExceptionAsError(new RuntimeException("PowerShell process encountered unexpected"
+ " error when closing.", e));
} finally {
this.commandWriter.close();

try {
if (process.isAlive()) {
process.getInputStream().close();
}
} catch (IOException ex) {
LOGGER.logExceptionAsError(new RuntimeException("PowerShell stream encountered unexpected"
+ " error when closing.", ex));
}
this.closed = true;
}
return this.closed;
};
return executorService != null ? Mono.fromFuture(CompletableFuture.supplyAsync(supplier, executorService))
: Mono.fromFuture(CompletableFuture.supplyAsync(supplier));
} else {
return Mono.just(true);
}
});
}
}
Loading