Skip to content

Commit

Permalink
[trinodb#4] allow presto jdbc user to access trino server correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
向阿鲲 authored and fengguangyuan committed Mar 9, 2022
1 parent aafd909 commit cf039a4
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,20 +98,26 @@ public SessionContext createSessionContext(
catch (ProtocolDetectionException e) {
throw badRequest(e.getMessage());
}
Optional<String> catalog = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestCatalog())));
Optional<String> schema = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestSchema())));
String trinoCatalog = trimEmptyToNull(headers.getFirst(protocolHeaders.requestCatalog()));
Optional<String> catalog = Optional.ofNullable(
trinoCatalog == null ? trimEmptyToNull(headers.getFirst("X-Presto-Catalog")) : trinoCatalog);
String trinoSchema = trimEmptyToNull(headers.getFirst(protocolHeaders.requestSchema()));
Optional<String> schema = Optional.ofNullable(
trinoSchema == null ? trimEmptyToNull(headers.getFirst("X-Presto-Schema")) : trinoSchema);
Optional<String> path = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestPath())));
assertRequest((catalog.isPresent()) || (schema.isEmpty()), "Schema is set but catalog is not");

requireNonNull(authenticatedIdentity, "authenticatedIdentity is null");
Identity identity = buildSessionIdentity(authenticatedIdentity, protocolHeaders, headers);
SelectedRole selectedRole = parseSystemRoleHeaders(protocolHeaders, headers);

Optional<String> source = Optional.ofNullable(headers.getFirst(protocolHeaders.requestSource()));
String trinoSource = headers.getFirst(protocolHeaders.requestSource());
Optional<String> source = Optional.ofNullable(trinoSource == null ? headers.getFirst("X-Presto-Source") : trinoSource);
Optional<String> traceToken = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestTraceToken())));
Optional<String> userAgent = Optional.ofNullable(headers.getFirst(USER_AGENT));
Optional<String> remoteUserAddress = requireNonNull(remoteAddress, "remoteAddress is null");
Optional<String> timeZoneId = Optional.ofNullable(headers.getFirst(protocolHeaders.requestTimeZone()));
String trinoTimeZoneId = headers.getFirst(protocolHeaders.requestTimeZone());
Optional<String> timeZoneId = Optional.ofNullable(trinoTimeZoneId == null ? headers.getFirst("X-Presto-Time-Zone") : trinoTimeZoneId);
Optional<String> language = Optional.ofNullable(headers.getFirst(protocolHeaders.requestLanguage()));
Optional<String> clientInfo = Optional.ofNullable(headers.getFirst(protocolHeaders.requestClientInfo()));
Set<String> clientTags = parseClientTags(protocolHeaders, headers);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.server.security;

import io.airlift.log.Logger;
import io.trino.client.ProtocolDetectionException;
import io.trino.client.ProtocolHeaders;
import io.trino.server.ProtocolConfig;
Expand All @@ -34,6 +35,8 @@
public class InsecureAuthenticator
implements Authenticator
{
private static final Logger log = Logger.get(InsecureAuthenticator.class);

private final UserMapping userMapping;
private final Optional<String> alternateHeaderName;

Expand Down Expand Up @@ -62,6 +65,12 @@ public Identity authenticate(ContainerRequestContext request)
try {
ProtocolHeaders protocolHeaders = detectProtocol(alternateHeaderName, request.getHeaders().keySet());
user = emptyToNull(request.getHeaders().getFirst(protocolHeaders.requestUser()));
if (user == null) {
user = emptyToNull(request.getHeaders().getFirst("X-Presto-User"));
if (user != null) {
log.warn("InsecureAuthenticator user is presto user [%s]!", user);
}
}
}
catch (ProtocolDetectionException e) {
// ignored
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public void testSessionContext()
{
assertSessionContext(TRINO_HEADERS);
assertSessionContext(createProtocolHeaders("taco"));
assertSessionContextForPresto(createProtocolHeaders("taco"));
}

private static void assertSessionContext(ProtocolHeaders protocolHeaders)
Expand Down Expand Up @@ -100,6 +101,27 @@ private static void assertSessionContext(ProtocolHeaders protocolHeaders)
assertEquals(context.getIdentity().getGroups(), ImmutableSet.of("testUser"));
}

private static void assertSessionContextForPresto(ProtocolHeaders protocolHeaders)
{
MultivaluedMap<String, String> headers = new GuavaMultivaluedMap<>(ImmutableListMultimap.<String, String>builder()
.put(protocolHeaders.requestUser(), "testUser")
.put("X-Presto-Source", "testSource")
.put("X-Presto-Catalog", "testCatalog")
.put("X-Presto-Schema", "testSchema")
.put("X-Presto-Time-Zone", "Asia/Taipei")
.build());

SessionContext context = SESSION_CONTEXT_FACTORY.createSessionContext(
headers,
Optional.of(protocolHeaders.getProtocolName()),
Optional.of("testRemote"),
Optional.empty());
assertEquals(context.getSource().orElse(null), "testSource");
assertEquals(context.getCatalog().orElse(null), "testCatalog");
assertEquals(context.getSchema().orElse(null), "testSchema");
assertEquals(context.getTimeZoneId().orElse(null), "Asia/Taipei");
}

@Test
public void testMappedUser()
{
Expand Down

0 comments on commit cf039a4

Please sign in to comment.