Skip to content

Commit

Permalink
[Backport 2.x] Service accounts and on-behalf-of authentication in 2.x (
Browse files Browse the repository at this point in the history
#11052)

* Implement on behalf of token passing for extensions (#8679)

* Provide service accounts tokens to extensions (#9618)

This change adds a new transport action which passes the extension a string representation of its service account auth token. This token is created by the TokenManager interface implementation. The token is expected to be an encoded basic auth credential string which can be used by the extension to interact with its own system index.

* Cherry pick #10614 and #10664

Signed-off-by: Stephen Crawford <[email protected]>
Signed-off-by: Stephen Crawford <[email protected]>
Signed-off-by: Ryan Liang <[email protected]>
Signed-off-by: Peter Nied <[email protected]>
Co-authored-by: Stephen Crawford <[email protected]>
Co-authored-by: Peter Nied <[email protected]>
Co-authored-by: Owais Kazi <[email protected]>
Co-authored-by: Peter Nied <[email protected]>
  • Loading branch information
5 people authored Nov 2, 2023
1 parent 0d30590 commit f1df1cd
Show file tree
Hide file tree
Showing 22 changed files with 382 additions and 98 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Update the indexRandom function to create more segments for concurrent search tests ([10247](https://github.com/opensearch-project/OpenSearch/pull/10247))
- Add support for query profiler with concurrent aggregation ([#9248](https://github.com/opensearch-project/OpenSearch/pull/9248))
- Introduce ConcurrentQueryProfiler to profile query using concurrent segment search path and support concurrency during rewrite and create weight ([10352](https://github.com/opensearch-project/OpenSearch/pull/10352))
- Implement on behalf of token passing for extensions ([#8679](https://github.com/opensearch-project/OpenSearch/pull/8679))
- Provide service accounts tokens to extensions ([#9618](https://github.com/opensearch-project/OpenSearch/pull/9618))

### Dependencies
- Bumps jetty version to 9.4.52.v20230823 to fix GMS-2023-1857 ([#9822](https://github.com/opensearch-project/OpenSearch/pull/9822))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.authc.UsernamePasswordToken;
import org.opensearch.common.Randomness;
import org.opensearch.identity.IdentityService;
import org.opensearch.identity.Subject;
import org.opensearch.identity.tokens.AuthToken;
import org.opensearch.identity.tokens.BasicAuthToken;
import org.opensearch.identity.tokens.OnBehalfOfClaims;
import org.opensearch.identity.tokens.TokenManager;

import java.util.Arrays;
Expand Down Expand Up @@ -54,15 +55,16 @@ public Optional<AuthenticationToken> translateAuthToken(org.opensearch.identity.
final BasicAuthToken basicAuthToken = (BasicAuthToken) authenticationToken;
return Optional.of(new UsernamePasswordToken(basicAuthToken.getUser(), basicAuthToken.getPassword()));
}

return Optional.empty();
}

@Override
public AuthToken issueToken(String audience) {
public AuthToken issueOnBehalfOfToken(Subject subject, OnBehalfOfClaims claims) {

String password = generatePassword();
final byte[] rawEncoded = Base64.getEncoder().encode((audience + ":" + password).getBytes(UTF_8));
// Make a new ShiroSubject audience as name
final byte[] rawEncoded = Base64.getUrlEncoder().encode((claims.getAudience() + ":" + password).getBytes(UTF_8));

final String usernamePassword = new String(rawEncoded, UTF_8);
final String header = "Basic " + usernamePassword;
BasicAuthToken token = new BasicAuthToken(header);
Expand All @@ -71,13 +73,17 @@ public AuthToken issueToken(String audience) {
return token;
}

public boolean validateToken(AuthToken token) {
if (token instanceof BasicAuthToken) {
final BasicAuthToken basicAuthToken = (BasicAuthToken) token;
return basicAuthToken.getUser().equals(SecurityUtils.getSubject().toString())
&& basicAuthToken.getPassword().equals(shiroTokenPasswordMap.get(basicAuthToken));
}
return false;
@Override
public AuthToken issueServiceAccountToken(String audience) {

String password = generatePassword();
final byte[] rawEncoded = Base64.getUrlEncoder().withoutPadding().encode((audience + ":" + password).getBytes(UTF_8)); // Make a new
final String usernamePassword = new String(rawEncoded, UTF_8);
final String header = "Basic " + usernamePassword;

BasicAuthToken token = new BasicAuthToken(header);
shiroTokenPasswordMap.put(token, password);
return token;
}

public String getTokenInfo(AuthToken token) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@

import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.authc.UsernamePasswordToken;
import org.opensearch.identity.Subject;
import org.opensearch.identity.noop.NoopSubject;
import org.opensearch.identity.noop.NoopTokenManager;
import org.opensearch.identity.tokens.AuthToken;
import org.opensearch.identity.tokens.BasicAuthToken;
import org.opensearch.identity.tokens.BearerAuthToken;
import org.opensearch.identity.tokens.OnBehalfOfClaims;
import org.opensearch.test.OpenSearchTestCase;
import org.junit.Before;

Expand All @@ -34,16 +37,15 @@
public class AuthTokenHandlerTests extends OpenSearchTestCase {

private ShiroTokenManager shiroAuthTokenHandler;
private NoopTokenManager noopTokenManager;

@Before
public void testSetup() {
shiroAuthTokenHandler = new ShiroTokenManager();
noopTokenManager = new NoopTokenManager();
}

public void testShouldExtractBasicAuthTokenSuccessfully() {
final BasicAuthToken authToken = new BasicAuthToken("Basic YWRtaW46YWRtaW4="); // admin:admin
assertEquals(authToken.asAuthHeaderValue(), "YWRtaW46YWRtaW4=");

final AuthenticationToken translatedToken = shiroAuthTokenHandler.translateAuthToken(authToken).get();
assertThat(translatedToken, is(instanceOf(UsernamePasswordToken.class)));
Expand Down Expand Up @@ -98,18 +100,13 @@ public void testShouldFailGetTokenInfo() {
assertThrows(UnsupportedAuthenticationToken.class, () -> shiroAuthTokenHandler.getTokenInfo(bearerAuthToken));
}

public void testShouldFailValidateToken() {
final BearerAuthToken bearerAuthToken = new BearerAuthToken("header.payload.signature");
assertFalse(shiroAuthTokenHandler.validateToken(bearerAuthToken));
}

public void testShoudPassMapLookupWithToken() {
final BasicAuthToken authToken = new BasicAuthToken("Basic dGVzdDp0ZTpzdA==");
shiroAuthTokenHandler.getShiroTokenPasswordMap().put(authToken, "te:st");
assertTrue(authToken.getPassword().equals(shiroAuthTokenHandler.getShiroTokenPasswordMap().get(authToken)));
}

public void testShouldPassThrougbResetToken(AuthToken token) {
public void testShouldPassThroughResetToken() {
final BearerAuthToken bearerAuthToken = new BearerAuthToken("header.payload.signature");
shiroAuthTokenHandler.resetToken(bearerAuthToken);
}
Expand All @@ -124,6 +121,7 @@ public void testVerifyBearerTokenObject() {
assertEquals(testGoodToken.getPayload(), "payload");
assertEquals(testGoodToken.getSignature(), "signature");
assertEquals(testGoodToken.toString(), "Bearer auth token with header=header, payload=payload, signature=signature");
assertEquals(testGoodToken.asAuthHeaderValue(), "header.payload.signature");
}

public void testGeneratedPasswordContents() {
Expand All @@ -147,4 +145,35 @@ public void testGeneratedPasswordContents() {
validator.validate(data);
}

public void testIssueOnBehalfOfTokenFromClaims() {
Subject subject = new NoopSubject();
OnBehalfOfClaims claims = new OnBehalfOfClaims("test", "test");
BasicAuthToken authToken = (BasicAuthToken) shiroAuthTokenHandler.issueOnBehalfOfToken(subject, claims);
assertTrue(authToken instanceof BasicAuthToken);
UsernamePasswordToken translatedToken = (UsernamePasswordToken) shiroAuthTokenHandler.translateAuthToken(authToken).get();
assertEquals(authToken.getPassword(), new String(translatedToken.getPassword()));
assertTrue(shiroAuthTokenHandler.getShiroTokenPasswordMap().containsKey(authToken));
assertEquals(shiroAuthTokenHandler.getShiroTokenPasswordMap().get(authToken), new String(translatedToken.getPassword()));
}

public void testTokenNoopIssuance() {
NoopTokenManager tokenManager = new NoopTokenManager();
OnBehalfOfClaims claims = new OnBehalfOfClaims("test", "test");
Subject subject = new NoopSubject();
AuthToken token = tokenManager.issueOnBehalfOfToken(subject, claims);
assertTrue(token instanceof AuthToken);
AuthToken serviceAccountToken = tokenManager.issueServiceAccountToken("test");
assertTrue(serviceAccountToken instanceof AuthToken);
assertEquals(serviceAccountToken.asAuthHeaderValue(), "noopToken");
}

public void testShouldSucceedIssueServiceAccountToken() {
String audience = "testExtensionName";
BasicAuthToken authToken = (BasicAuthToken) shiroAuthTokenHandler.issueServiceAccountToken(audience);
assertTrue(authToken instanceof BasicAuthToken);
UsernamePasswordToken translatedToken = (UsernamePasswordToken) shiroAuthTokenHandler.translateAuthToken(authToken).get();
assertEquals(authToken.getPassword(), new String(translatedToken.getPassword()));
assertTrue(shiroAuthTokenHandler.getShiroTokenPasswordMap().containsKey(authToken));
assertEquals(shiroAuthTokenHandler.getShiroTokenPasswordMap().get(authToken), new String(translatedToken.getPassword()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,27 @@
public class InitializeExtensionRequest extends TransportRequest {
private final DiscoveryNode sourceNode;
private final DiscoveryExtensionNode extension;
private final String serviceAccountHeader;

public InitializeExtensionRequest(DiscoveryNode sourceNode, DiscoveryExtensionNode extension) {
public InitializeExtensionRequest(DiscoveryNode sourceNode, DiscoveryExtensionNode extension, String serviceAccountHeader) {
this.sourceNode = sourceNode;
this.extension = extension;
this.serviceAccountHeader = serviceAccountHeader;
}

public InitializeExtensionRequest(StreamInput in) throws IOException {
super(in);
sourceNode = new DiscoveryNode(in);
extension = new DiscoveryExtensionNode(in);
serviceAccountHeader = in.readString();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
sourceNode.writeTo(out);
extension.writeTo(out);
out.writeString(serviceAccountHeader);
}

public DiscoveryNode getSourceNode() {
Expand All @@ -52,6 +56,10 @@ public DiscoveryExtensionNode getExtension() {
return extension;
}

public String getServiceAccountHeader() {
return serviceAccountHeader;
}

@Override
public String toString() {
return "InitializeExtensionsRequest{" + "sourceNode=" + sourceNode + ", extension=" + extension + '}';
Expand All @@ -62,7 +70,9 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InitializeExtensionRequest that = (InitializeExtensionRequest) o;
return Objects.equals(sourceNode, that.sourceNode) && Objects.equals(extension, that.extension);
return Objects.equals(sourceNode, that.sourceNode)
&& Objects.equals(extension, that.extension)
&& Objects.equals(serviceAccountHeader, that.getServiceAccountHeader());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import org.opensearch.extensions.rest.RestActionsRequestHandler;
import org.opensearch.extensions.settings.CustomSettingsRequestHandler;
import org.opensearch.extensions.settings.RegisterCustomSettingsRequest;
import org.opensearch.identity.IdentityService;
import org.opensearch.identity.tokens.AuthToken;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.ConnectTransportException;
import org.opensearch.transport.TransportException;
Expand Down Expand Up @@ -100,14 +102,15 @@ public static enum OpenSearchRequestType {
private Settings environmentSettings;
private AddSettingsUpdateConsumerRequestHandler addSettingsUpdateConsumerRequestHandler;
private NodeClient client;
private IdentityService identityService;

/**
* Instantiate a new ExtensionsManager object to handle requests and responses from extensions. This is called during Node bootstrap.
*
* @param additionalSettings Additional settings to read in from extension initialization request
* @throws IOException If the extensions discovery file is not properly retrieved.
*/
public ExtensionsManager(Set<Setting<?>> additionalSettings) throws IOException {
public ExtensionsManager(Set<Setting<?>> additionalSettings, IdentityService identityService) throws IOException {
logger.info("ExtensionsManager initialized");
this.initializedExtensions = new HashMap<String, DiscoveryExtensionNode>();
this.extensionIdMap = new HashMap<String, DiscoveryExtensionNode>();
Expand All @@ -122,6 +125,7 @@ public ExtensionsManager(Set<Setting<?>> additionalSettings) throws IOException
}
this.client = null;
this.extensionTransportActionsHandler = null;
this.identityService = identityService;
}

/**
Expand All @@ -141,9 +145,15 @@ public void initializeServicesAndRestHandler(
TransportService transportService,
ClusterService clusterService,
Settings initialEnvironmentSettings,
NodeClient client
NodeClient client,
IdentityService identityService
) {
this.restActionsRequestHandler = new RestActionsRequestHandler(actionModule.getRestController(), extensionIdMap, transportService);
this.restActionsRequestHandler = new RestActionsRequestHandler(
actionModule.getRestController(),
extensionIdMap,
transportService,
identityService
);
this.customSettingsRequestHandler = new CustomSettingsRequestHandler(settingsModule);
this.transportService = transportService;
this.clusterService = clusterService;
Expand Down Expand Up @@ -399,7 +409,7 @@ protected void doRun() throws Exception {
transportService.sendRequest(
extensionNode,
REQUEST_EXTENSION_ACTION_NAME,
new InitializeExtensionRequest(transportService.getLocalNode(), extensionNode),
new InitializeExtensionRequest(transportService.getLocalNode(), extensionNode, issueServiceAccount(extensionNode)),
initializeExtensionResponseHandler
);
}
Expand Down Expand Up @@ -442,6 +452,15 @@ TransportResponse handleExtensionRequest(ExtensionRequest extensionRequest) thro
}
}

/**
* A helper method called during initialization that issues a service accounts to extensions
* @param extension The extension to be issued a service account
*/
private String issueServiceAccount(DiscoveryExtensionNode extension) {
AuthToken serviceAccountToken = identityService.getTokenManager().issueServiceAccountToken(extension.getId());
return serviceAccountToken.asAuthHeaderValue();
}

static String getRequestExtensionActionName() {
return REQUEST_EXTENSION_ACTION_NAME;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
import org.opensearch.extensions.action.ExtensionActionRequest;
import org.opensearch.extensions.action.ExtensionActionResponse;
import org.opensearch.extensions.action.RemoteExtensionActionResponse;
import org.opensearch.identity.IdentityService;
import org.opensearch.transport.TransportService;

import java.io.IOException;
import java.util.List;
import java.util.Optional;
import java.util.Set;

Expand All @@ -30,7 +32,7 @@
public class NoopExtensionsManager extends ExtensionsManager {

public NoopExtensionsManager() throws IOException {
super(Set.of());
super(Set.of(), new IdentityService(Settings.EMPTY, List.of()));
}

@Override
Expand All @@ -40,7 +42,8 @@ public void initializeServicesAndRestHandler(
TransportService transportService,
ClusterService clusterService,
Settings initialEnvironmentSettings,
NodeClient client
NodeClient client,
IdentityService identityService
) {
// no-op
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import java.util.Objects;
import java.util.Set;

import static java.util.Objects.requireNonNull;

/**
* Request to execute REST actions on extension node.
* This contains necessary portions of a {@link RestRequest} object, but does not pass the full request for security concerns.
Expand Down Expand Up @@ -86,7 +88,7 @@ public ExtensionRestRequest(
this.headers = headers;
this.mediaType = mediaType;
this.content = content;
this.principalIdentifierToken = principalIdentifier;
this.principalIdentifierToken = requireNonNull(principalIdentifier);
this.httpVersion = httpVersion;
}

Expand Down Expand Up @@ -280,7 +282,7 @@ public boolean isContentConsumed() {
}

/**
* Gets a parser for the contents of this request if there is content and an xContentType.
* Gets a parser for the contents of this request if there is content, an xContentType, and a principal identifier.
*
* @param xContentRegistry The extension's xContentRegistry
* @return A parser for the given content and content type.
Expand All @@ -291,6 +293,9 @@ public final XContentParser contentParser(NamedXContentRegistry xContentRegistry
if (!hasContent() || getXContentType() == null) {
throw new OpenSearchParseException("There is no request body or the ContentType is invalid.");
}
if (getRequestIssuerIdentity() == null) {
throw new OpenSearchParseException("There is no request body or the requester identity is invalid.");
}
return getXContentType().xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, content.streamInput());
}

Expand Down
Loading

0 comments on commit f1df1cd

Please sign in to comment.