Skip to content

Commit

Permalink
Support passing groups in OAuth access token claim
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasz-walkiewicz authored and kokosing committed Jan 24, 2022
1 parent cf57470 commit 8b8b0be
Show file tree
Hide file tree
Showing 11 changed files with 148 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.server.security.oauth2;

import com.google.common.collect.ImmutableSet;
import io.trino.server.security.AbstractBearerAuthenticator;
import io.trino.server.security.AuthenticationException;
import io.trino.server.security.UserMapping;
Expand All @@ -24,6 +25,7 @@
import javax.ws.rs.container.ContainerRequestContext;

import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
Expand All @@ -39,13 +41,15 @@ public class OAuth2Authenticator
{
private final OAuth2Service service;
private final String principalField;
private final Optional<String> groupsField;
private final UserMapping userMapping;

@Inject
public OAuth2Authenticator(OAuth2Service service, OAuth2Config config)
{
this.service = requireNonNull(service, "service is null");
this.principalField = config.getPrincipalField();
groupsField = requireNonNull(config.getGroupsField(), "groupsField is null");
userMapping = createUserMapping(config.getUserMappingPattern(), config.getUserMappingFile());
}

Expand All @@ -59,9 +63,11 @@ protected Optional<Identity> createIdentity(String token)
return Optional.empty();
}
String principal = (String) claims.get().get(principalField);
return Optional.of(Identity.forUser(userMapping.mapUser(principal))
.withPrincipal(new BasicPrincipal(principal))
.build());
Identity.Builder builder = Identity.forUser(userMapping.mapUser(principal));
builder.withPrincipal(new BasicPrincipal(principal));
groupsField.flatMap(field -> Optional.ofNullable((List<String>) claims.get().get(field)))
.ifPresent(groups -> builder.withGroups(ImmutableSet.copyOf(groups)));
return Optional.of(builder.build());
}
catch (ChallengeFailedException e) {
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public class OAuth2Config
private String clientSecret;
private Set<String> scopes = ImmutableSet.of(OPENID_SCOPE);
private String principalField = "sub";
private Optional<String> groupsField = Optional.empty();
private List<String> additionalAudiences = Collections.emptyList();
private Duration challengeTimeout = new Duration(15, TimeUnit.MINUTES);
private Optional<String> userMappingPattern = Optional.empty();
Expand Down Expand Up @@ -222,6 +223,19 @@ public OAuth2Config setPrincipalField(String principalField)
return this;
}

public Optional<String> getGroupsField()
{
return groupsField;
}

@Config("http-server.authentication.oauth2.groups-field")
@ConfigDescription("Groups field in the claim")
public OAuth2Config setGroupsField(String groupsField)
{
this.groupsField = Optional.ofNullable(groupsField);
return this;
}

@MinDuration("1ms")
@NotNull
public Duration getChallengeTimeout()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.core.Response;

import java.util.List;
import java.util.Map;
import java.util.Optional;

Expand All @@ -50,6 +51,7 @@ public class OAuth2WebUiAuthenticationFilter
private final String principalField;
private final OAuth2Service service;
private final UserMapping userMapping;
private final Optional<String> groupsField;

@Inject
public OAuth2WebUiAuthenticationFilter(OAuth2Service service, OAuth2Config oauth2Config)
Expand All @@ -58,6 +60,7 @@ public OAuth2WebUiAuthenticationFilter(OAuth2Service service, OAuth2Config oauth
requireNonNull(oauth2Config, "oauth2Config is null");
this.userMapping = UserMapping.createUserMapping(oauth2Config.getUserMappingPattern(), oauth2Config.getUserMappingFile());
this.principalField = oauth2Config.getPrincipalField();
groupsField = requireNonNull(oauth2Config.getGroupsField(), "groupsField is null");
}

@Override
Expand Down Expand Up @@ -101,9 +104,11 @@ public void filter(ContainerRequestContext request)
return;
}
String principalName = (String) principal;
setAuthenticatedIdentity(request, Identity.forUser(userMapping.mapUser(principalName))
.withPrincipal(new BasicPrincipal(principalName))
.build());
Identity.Builder builder = Identity.forUser(userMapping.mapUser(principalName));
builder.withPrincipal(new BasicPrincipal(principalName));
groupsField.flatMap(field -> Optional.ofNullable((List<String>) claims.get().get(field)))
.ifPresent(groups -> builder.withGroups(ImmutableSet.copyOf(groups)));
setAuthenticatedIdentity(request, builder.build());
}
catch (UserMappingException e) {
sendErrorMessage(request, UNAUTHORIZED, firstNonNull(e.getMessage(), "Unauthorized"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.Resources;
Expand All @@ -38,13 +39,17 @@
import io.trino.spi.security.BasicPrincipal;
import io.trino.spi.security.Identity;
import io.trino.spi.security.SystemSecurityContext;
import okhttp3.Cookie;
import okhttp3.CookieJar;
import okhttp3.Credentials;
import okhttp3.Headers;
import okhttp3.HttpUrl;
import okhttp3.JavaNetCookieJar;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import javax.crypto.SecretKey;
Expand Down Expand Up @@ -90,7 +95,10 @@
import static io.trino.client.ProtocolHeaders.TRINO_HEADERS;
import static io.trino.metadata.MetadataManager.createTestMetadataManager;
import static io.trino.server.security.ResourceSecurity.AccessType.AUTHENTICATED_USER;
import static io.trino.server.security.ResourceSecurity.AccessType.WEB_UI;
import static io.trino.server.security.oauth2.OAuth2Service.NONCE;
import static io.trino.server.ui.FormWebUiAuthenticationFilter.UI_LOCATION;
import static io.trino.server.ui.OAuthWebUiCookie.OAUTH2_COOKIE;
import static io.trino.spi.security.AccessDeniedException.denyImpersonateUser;
import static io.trino.spi.security.AccessDeniedException.denyReadSystemInformationAccess;
import static io.trino.testing.assertions.Assert.assertEquals;
Expand Down Expand Up @@ -130,6 +138,7 @@ public class TestResourceSecurity
private static final String MANAGEMENT_PASSWORD = "management-password";
private static final String HMAC_KEY = Resources.getResource("hmac_key.txt").getPath();
private static final String JWK_KEY_ID = "test-rsa";
private static final String GROUPS_CLAIM = "groups";
private static final PrivateKey JWK_PRIVATE_KEY;
private static final ObjectMapper json = new ObjectMapper();

Expand Down Expand Up @@ -655,6 +664,86 @@ public HttpCookie getNonceCookie()
}
}

@Test(dataProvider = "groups")
public void testOAuth2Groups(Optional<Set<String>> groups)
throws Exception
{
try (TokenServer tokenServer = new TokenServer(Optional.empty());
TestingTrinoServer server = TestingTrinoServer.builder()
.setProperties(ImmutableMap.<String, String>builder()
.putAll(SECURE_PROPERTIES)
.put("web-ui.enabled", "true")
.putAll(getOAuth2Properties(tokenServer))
.put("http-server.authentication.oauth2.groups-field", GROUPS_CLAIM)
.build())
.setAdditionalModule(oauth2Module(tokenServer))
.build()) {
server.getInstance(Key.get(AccessControlManager.class)).addSystemAccessControl(TestSystemAccessControl.NO_IMPERSONATION);
HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class));

String accessToken = tokenServer.issueAccessToken(groups);
OkHttpClient clientWithOAuthToken = client.newBuilder()
.authenticator((route, response) -> response.request().newBuilder()
.header(AUTHORIZATION, "Bearer " + accessToken)
.build())
.build();

assertAuthenticationAutomatic(httpServerInfo.getHttpsUri(), clientWithOAuthToken);

try (Response response = clientWithOAuthToken.newCall(new Request.Builder()
.url(getLocation(httpServerInfo.getHttpsUri(), "/protocol/identity"))
.build())
.execute()) {
assertEquals(response.code(), SC_OK);
assertEquals(response.header("user"), TEST_USER);
assertEquals(response.header("principal"), TEST_USER);
assertEquals(response.header("groups"), groups.map(TestResource::toHeader).orElse(""));
}

OkHttpClient clientWithOAuthCookie = client.newBuilder()
.cookieJar(new CookieJar()
{
@Override
public void saveFromResponse(HttpUrl url, List<Cookie> cookies)
{
}

@Override
public List<Cookie> loadForRequest(HttpUrl url)
{
return ImmutableList.of(new Cookie.Builder()
.domain(httpServerInfo.getHttpsUri().getHost())
.path(UI_LOCATION)
.name(OAUTH2_COOKIE)
.value(accessToken)
.httpOnly()
.secure()
.build());
}
})
.build();
try (Response response = clientWithOAuthCookie.newCall(new Request.Builder()
.url(getLocation(httpServerInfo.getHttpsUri(), "/ui/api/identity"))
.build())
.execute()) {
assertEquals(response.code(), SC_OK);
assertEquals(response.header("user"), TEST_USER);
assertEquals(response.header("principal"), TEST_USER);
assertEquals(response.header("groups"), groups.map(TestResource::toHeader).orElse(""));
}
}
}

@DataProvider(name = "groups")
public static Object[][] groups()
{
return new Object[][] {
{Optional.empty()},
{Optional.of(ImmutableSet.of())},
{Optional.of(ImmutableSet.of("admin", "public"))}
};
}

private static Module oauth2Module(TokenServer tokenServer)
{
return binder -> {
Expand Down Expand Up @@ -707,7 +796,7 @@ public TokenServer(Optional<String> principalField)
this.principalField = requireNonNull(principalField, "principalField is null");
jwkServer = createTestingJwkServer();
jwkServer.start();
accessToken = issueAccessToken();
accessToken = issueAccessToken(Optional.empty());
}

@Override
Expand Down Expand Up @@ -767,7 +856,7 @@ public String getAccessToken()
return accessToken;
}

private String issueAccessToken()
public String issueAccessToken(Optional<Set<String>> groups)
{
JwtBuilder accessToken = Jwts.builder()
.signWith(JWK_PRIVATE_KEY)
Expand All @@ -781,6 +870,7 @@ private String issueAccessToken()
else {
accessToken.setSubject(TEST_USER);
}
groups.ifPresent(groupsClaim -> accessToken.claim(GROUPS_CLAIM, groupsClaim));
return accessToken.compact();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public void testDefaults()
.setScopes("openid")
.setChallengeTimeout(new Duration(15, MINUTES))
.setPrincipalField("sub")
.setGroupsField(null)
.setAdditionalAudiences(Collections.emptyList())
.setUserMappingPattern(null)
.setUserMappingFile(null));
Expand All @@ -70,6 +71,7 @@ public void testExplicitPropertyMappings()
.put("http-server.authentication.oauth2.client-secret", "consumer-secret")
.put("http-server.authentication.oauth2.scopes", "email,offline")
.put("http-server.authentication.oauth2.principal-field", "some-field")
.put("http-server.authentication.oauth2.groups-field", "groups")
.put("http-server.authentication.oauth2.additional-audiences", "test-aud1,test-aud2")
.put("http-server.authentication.oauth2.challenge-timeout", "90s")
.put("http-server.authentication.oauth2.user-mapping.pattern", "(.*)@something")
Expand All @@ -88,6 +90,7 @@ public void testExplicitPropertyMappings()
.setClientSecret("consumer-secret")
.setScopes("email, offline")
.setPrincipalField("some-field")
.setGroupsField("groups")
.setAdditionalAudiences(List.of("test-aud1", "test-aud2"))
.setChallengeTimeout(new Duration(90, SECONDS))
.setUserMappingPattern("(.*)@something")
Expand Down
2 changes: 2 additions & 0 deletions docs/src/main/sphinx/security/oauth2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ The following configuration properties are available:
for more information.
* - ``http-server.authentication.oauth2.principal-field``
- The field of the access token used for the Trino user principal. Defaults to ``sub``. Other commonly used fields include ``sAMAccountName``, ``name``, ``upn``, and ``email``.
* - ``http-server.authentication.oauth2.groups-field``
- The field of the access token used for Trino groups. The corresponding claim value must be an array.


Troubleshooting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ public void extendEnvironment(Environment.Builder builder)
.withEnv("SERVE_TLS_CERT_PATH", "/tmp/certs/hydra.pem")
.withEnv("STRATEGIES_ACCESS_TOKEN", "jwt")
.withEnv("TTL_ACCESS_TOKEN", TTL_ACCESS_TOKEN_IN_SECONDS + "s")
.withEnv("OAUTH2_ALLOWED_TOP_LEVEL_CLAIMS", "groups")
.withCommand("serve", "all")
.withCopyFileToContainer(forHostPath(configDir.getPath("cert/hydra.pem")), "/tmp/certs/hydra.pem")
.waitingFor(new WaitAllStrategy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ http-server.authentication.oauth2.jwks-url=https://hydra:4444/.well-known/jwks.j
http-server.authentication.oauth2.client-id=trinodb_client_id
http-server.authentication.oauth2.client-secret=trinodb_client_secret
http-server.authentication.oauth2.user-mapping.pattern=(.*)(@.*)?
http-server.authentication.oauth2.groups-field=groups
oauth2-jwk.http-client.trust-store-path=/docker/presto-product-tests/conf/presto/etc/hydra.pem
oauth2-jwk.http-client.http-proxy=proxy:8888
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ http-server.authentication.oauth2.jwks-url=https://hydra:4444/.well-known/jwks.j
http-server.authentication.oauth2.client-id=trinodb_client_id
http-server.authentication.oauth2.client-secret=trinodb_client_secret
http-server.authentication.oauth2.user-mapping.pattern=(.*)(@.*)?
http-server.authentication.oauth2.groups-field=groups
oauth2-jwk.http-client.trust-store-path=/docker/presto-product-tests/conf/presto/etc/cert/truststore.jks
oauth2-jwk.http-client.trust-store-password=123456
oauth2-jwk.http-client.http-proxy=proxy:8888
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ http-server.authentication.oauth2.jwks-url=https://hydra:4444/.well-known/jwks.j
http-server.authentication.oauth2.client-id=trinodb_client_id
http-server.authentication.oauth2.client-secret=trinodb_client_secret
http-server.authentication.oauth2.user-mapping.pattern=(.*)(@.*)?
http-server.authentication.oauth2.groups-field=groups
oauth2-jwk.http-client.trust-store-path=/docker/presto-product-tests/conf/presto/etc/hydra.pem
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.tests.product.jdbc;

import com.google.common.collect.ImmutableList;
import com.google.inject.Inject;
import com.google.inject.name.Named;
import io.trino.jdbc.TestingRedirectHandlerInjector;
Expand Down Expand Up @@ -42,6 +43,7 @@
import java.util.Properties;

import static com.google.common.base.Preconditions.checkState;
import static io.trino.tempto.assertions.QueryAssert.Row.row;
import static io.trino.tempto.assertions.QueryAssert.assertThat;
import static io.trino.tempto.query.QueryResult.forResultSet;
import static io.trino.tests.product.TestGroups.OAUTH2;
Expand Down Expand Up @@ -126,6 +128,20 @@ public void shouldAuthenticateAfterTokenExpires()
}
}

@Test(groups = {OAUTH2, PROFILE_SPECIFIC_TESTS})
public void shouldReturnGroups()
throws SQLException
{
prepareHandler();
Properties properties = new Properties();
properties.setProperty("user", "test");
try (Connection connection = DriverManager.getConnection(jdbcUrl, properties);
PreparedStatement statement = connection.prepareStatement("SELECT array_sort(current_groups())");
ResultSet rs = statement.executeQuery()) {
assertThat(forResultSet(rs)).containsOnly(row(ImmutableList.of("admin", "public")));
}
}

private void prepareHandler()
{
TestingRedirectHandlerInjector.setRedirectHandler(uri -> {
Expand Down

0 comments on commit 8b8b0be

Please sign in to comment.