Skip to content

Commit

Permalink
Implemented Shared Key authentication for Batch Track 2 (#34871)
Browse files Browse the repository at this point in the history
* Implemented Shared Key auth support for Batch track 2

Implemented Shared Key authentication for Batch Track 2

Readding Date Header Policy and fixed Shared Key logic

* Readded generated annotation on token credential
  • Loading branch information
NickKouds authored and skapur12 committed Apr 18, 2024
1 parent 564bbe0 commit ba0a10b
Show file tree
Hide file tree
Showing 9 changed files with 576 additions and 215 deletions.
5 changes: 5 additions & 0 deletions sdk/batch/azure-compute-batch/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,10 @@
<version>4.13.2</version> <!-- {x-version-update;junit:junit;external_dependency} -->
<scope>test</scope>
</dependency>
<dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
<version>1.15</version> <!-- {x-version-update;commons-codec:commons-codec;external_dependency} -->
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// Code generated by Microsoft (R) AutoRest Code Generator.
package com.azure.compute.batch;

import com.azure.compute.batch.auth.BatchSharedKeyCredentials;
import com.azure.compute.batch.auth.BatchSharedKeyCredentialsPolicy;
import com.azure.compute.batch.implementation.BatchServiceClientImpl;
import com.azure.core.annotation.Generated;
import com.azure.core.annotation.ServiceClientBuilder;
Expand Down Expand Up @@ -186,6 +188,14 @@ public BatchServiceClientBuilder credential(TokenCredential tokenCredential) {
return this;
}

private BatchSharedKeyCredentials batchSharedKeyCred;

public BatchServiceClientBuilder credential(BatchSharedKeyCredentials batchSharedKeyCred) {
this.batchSharedKeyCred = Objects.requireNonNull(batchSharedKeyCred, "'batchSharedKeyCred' cannot be null.");
this.tokenCredential = null;
return this;
}

/*
* The service endpoint
*/
Expand Down Expand Up @@ -249,7 +259,6 @@ private BatchServiceClientImpl buildInnerClient() {
return client;
}

@Generated
private HttpPipeline createHttpPipeline() {
Configuration buildConfiguration =
(configuration == null) ? Configuration.getGlobalConfiguration() : configuration;
Expand Down Expand Up @@ -277,6 +286,9 @@ private HttpPipeline createHttpPipeline() {
if (tokenCredential != null) {
policies.add(new BearerTokenAuthenticationPolicy(tokenCredential, DEFAULT_SCOPES));
}
else if (batchSharedKeyCred != null) {
policies.add(new BatchSharedKeyCredentialsPolicy(batchSharedKeyCred));
}
this.pipelinePolicies.stream()
.filter(p -> p.getPipelinePosition() == HttpPipelinePosition.PER_RETRY)
.forEach(p -> policies.add(p));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package com.azure.compute.batch.auth;

public final class BatchSharedKeyCredentials {
private String accountName;

private String keyValue;

private String baseUrl;

/**
* Gets the Batch account name.
*
* @return The Batch account name.
*/
public String accountName() {
return accountName;
}

/**
* Gets the Base64 encoded account access key.
*
* @return The Base64 encoded account access key.
*/
public String keyValue() {
return keyValue;
}

/**
* Initializes a new instance of the {@link BatchSharedKeyCredentials} class with the specified Batch service endpoint, account name, and access key.
*
* @param baseUrl The Batch service endpoint.
* @param accountName The Batch account name.
* @param keyValue The Batch access key.
*/
public BatchSharedKeyCredentials(String baseUrl, String accountName, String keyValue) {

if (baseUrl == null) {
throw new IllegalArgumentException("Parameter baseUrl is required and cannot be null.");
}
if (accountName == null) {
throw new IllegalArgumentException("Parameter accountName is required and cannot be null.");
}
if (keyValue == null) {
throw new IllegalArgumentException("Parameter keyValue is required and cannot be null.");
}

this.baseUrl = baseUrl;
this.accountName = accountName;
this.keyValue = keyValue;
}

public String baseUrl() {
return this.baseUrl;
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package com.azure.compute.batch.auth;

import com.azure.core.http.*;
import com.azure.core.http.policy.HttpPipelinePolicy;
import com.azure.core.util.DateTimeRfc1123;
import com.azure.core.util.Header;
import org.apache.commons.codec.binary.Base64;
import reactor.core.publisher.Mono;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import java.io.IOException;
import java.net.URLDecoder;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.util.*;

import static java.time.OffsetDateTime.now;

public final class BatchSharedKeyCredentialsPolicy implements HttpPipelinePolicy {
private final BatchSharedKeyCredentials batchSharedKeyCred;
private Mac hmacSha256;

/**
* Creates a SharedKey pipeline policy that adds the SharedKey into the request's authorization header.
*
* @param credential the SharedKey credential used to create the policy.
*/
public BatchSharedKeyCredentialsPolicy(BatchSharedKeyCredentials credential) {
this.batchSharedKeyCred = credential;
}

/**
* @return the {@link BatchSharedKeyCredentials} linked to the policy.
*/

private String headerValue(HttpRequest request, HttpHeaderName headerName) {
HttpHeaders headers = request.getHeaders();
Header header = headers.get(headerName);
if (header == null) {
return "";
}

return header.getValue();
}

private synchronized String sign(String stringToSign) {
try {
// Encoding the Signature
// Signature=Base64(HMAC-SHA256(UTF8(StringToSign)))
byte[] digest = getHmac256().doFinal(stringToSign.getBytes("UTF-8"));
return Base64.encodeBase64String(digest);
} catch (Exception e) {
throw new IllegalArgumentException("accessKey", e);
}
}

private synchronized Mac getHmac256() throws NoSuchAlgorithmException, InvalidKeyException {
if (this.hmacSha256 == null) {
// Initializes the HMAC-SHA256 Mac and SecretKey.
this.hmacSha256 = Mac.getInstance("HmacSHA256");
this.hmacSha256.init(new SecretKeySpec(Base64.decodeBase64(batchSharedKeyCred.keyValue()), "HmacSHA256"));
}
return this.hmacSha256;
}

public String signHeader(HttpRequest request) throws IOException {

// Set Headers
String dateHeaderToSign = headerValue(request, HttpHeaderName.DATE);
if (request.getHeaders().get("ocp-date") == null) {
if (dateHeaderToSign == null) {
DateTimeRfc1123 rfcDate = new DateTimeRfc1123(now());
request.setHeader("ocp-date", rfcDate.toString());
dateHeaderToSign = ""; //Cannot append both ocp-date and date header values
}
}
else {
dateHeaderToSign = ""; //Cannot append both ocp-date and date header values
}

StringBuffer signature = new StringBuffer(request.getHttpMethod().toString());
signature.append("\n");
signature.append(headerValue(request, HttpHeaderName.CONTENT_ENCODING)).append("\n");
signature.append(headerValue(request, HttpHeaderName.CONTENT_LANGUAGE)).append("\n");

// Special handle content length
String contentLength = headerValue(request, HttpHeaderName.CONTENT_LENGTH);

signature.append((contentLength == null || Long.parseLong(contentLength) < 0 ? "" : contentLength)).append("\n");

signature.append(headerValue(request, HttpHeaderName.CONTENT_MD5)).append("\n");

String contentType = headerValue(request, HttpHeaderName.CONTENT_TYPE);
signature.append(contentType).append("\n");

signature.append(dateHeaderToSign).append("\n");
signature.append(headerValue(request, HttpHeaderName.IF_MODIFIED_SINCE)).append("\n");
signature.append(headerValue(request, HttpHeaderName.IF_MATCH)).append("\n");
signature.append(headerValue(request, HttpHeaderName.IF_NONE_MATCH)).append("\n");
signature.append(headerValue(request, HttpHeaderName.IF_UNMODIFIED_SINCE)).append("\n");
signature.append(headerValue(request, HttpHeaderName.RANGE)).append("\n");

ArrayList<String> customHeaders = new ArrayList<>();
for (HttpHeader name : request.getHeaders()) {
if (name.getName().toLowerCase().startsWith("ocp-")) {
customHeaders.add(name.getName().toLowerCase());
}
}

Collections.sort(customHeaders);
for (String canonicalHeader : customHeaders) {
String value = request.getHeaders().getValue(canonicalHeader);
value = value.replace('\n', ' ').replace('\r', ' ')
.replaceAll("^[ ]+", "");
signature.append(canonicalHeader).append(":").append(value).append("\n");
}

signature.append("/")
.append(batchSharedKeyCred.accountName().toLowerCase()).append("/")
.append(request.getUrl().getPath().replaceAll("^[/]+", ""));

String query = request.getUrl().getQuery();

if (query != null) {
Map<String, String> queryComponents = new TreeMap<>();
String[] pairs = query.split("&");
for (String pair : pairs) {
int idx = pair.indexOf("=");
String key = URLDecoder.decode(pair.substring(0, idx), "UTF-8")
.toLowerCase(Locale.US);
queryComponents.put(
key,
key + ":" + URLDecoder.decode(pair.substring(idx + 1), "UTF-8"));
}

for (Map.Entry<String, String> entry : queryComponents.entrySet()) {
signature.append("\n").append(entry.getValue());
}
}

String signedSignature = sign(signature.toString());
String authorization = "SharedKey " + batchSharedKeyCred.accountName()
+ ":" + signedSignature;

return authorization;
}

@Override
public Mono<HttpResponse> process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) {
try {
String authorizationValue = this.signHeader(context.getHttpRequest());
context.getHttpRequest().setHeader("Authorization", authorizationValue);
} catch (IOException e) {
throw new RuntimeException(e);
}
return next.process();
}
}
Loading

0 comments on commit ba0a10b

Please sign in to comment.