From cf039a46ed025528b2378c8833694834844f4eb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=91=E9=98=BF=E9=B2=B2?= Date: Fri, 13 Aug 2021 14:47:11 +0800 Subject: [PATCH] [#4] allow presto jdbc user to access trino server correctly --- .../HttpRequestSessionContextFactory.java | 14 ++++++++---- .../security/InsecureAuthenticator.java | 9 ++++++++ .../TestHttpRequestSessionContextFactory.java | 22 +++++++++++++++++++ 3 files changed, 41 insertions(+), 4 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java b/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java index cffd39c98716..1570378f58d8 100644 --- a/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java +++ b/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java @@ -98,8 +98,12 @@ public SessionContext createSessionContext( catch (ProtocolDetectionException e) { throw badRequest(e.getMessage()); } - Optional catalog = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestCatalog()))); - Optional schema = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestSchema()))); + String trinoCatalog = trimEmptyToNull(headers.getFirst(protocolHeaders.requestCatalog())); + Optional catalog = Optional.ofNullable( + trinoCatalog == null ? trimEmptyToNull(headers.getFirst("X-Presto-Catalog")) : trinoCatalog); + String trinoSchema = trimEmptyToNull(headers.getFirst(protocolHeaders.requestSchema())); + Optional schema = Optional.ofNullable( + trinoSchema == null ? trimEmptyToNull(headers.getFirst("X-Presto-Schema")) : trinoSchema); Optional path = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestPath()))); assertRequest((catalog.isPresent()) || (schema.isEmpty()), "Schema is set but catalog is not"); @@ -107,11 +111,13 @@ public SessionContext createSessionContext( Identity identity = buildSessionIdentity(authenticatedIdentity, protocolHeaders, headers); SelectedRole selectedRole = parseSystemRoleHeaders(protocolHeaders, headers); - Optional source = Optional.ofNullable(headers.getFirst(protocolHeaders.requestSource())); + String trinoSource = headers.getFirst(protocolHeaders.requestSource()); + Optional source = Optional.ofNullable(trinoSource == null ? headers.getFirst("X-Presto-Source") : trinoSource); Optional traceToken = Optional.ofNullable(trimEmptyToNull(headers.getFirst(protocolHeaders.requestTraceToken()))); Optional userAgent = Optional.ofNullable(headers.getFirst(USER_AGENT)); Optional remoteUserAddress = requireNonNull(remoteAddress, "remoteAddress is null"); - Optional timeZoneId = Optional.ofNullable(headers.getFirst(protocolHeaders.requestTimeZone())); + String trinoTimeZoneId = headers.getFirst(protocolHeaders.requestTimeZone()); + Optional timeZoneId = Optional.ofNullable(trinoTimeZoneId == null ? headers.getFirst("X-Presto-Time-Zone") : trinoTimeZoneId); Optional language = Optional.ofNullable(headers.getFirst(protocolHeaders.requestLanguage())); Optional clientInfo = Optional.ofNullable(headers.getFirst(protocolHeaders.requestClientInfo())); Set clientTags = parseClientTags(protocolHeaders, headers); diff --git a/core/trino-main/src/main/java/io/trino/server/security/InsecureAuthenticator.java b/core/trino-main/src/main/java/io/trino/server/security/InsecureAuthenticator.java index b6540c6fa8ee..6f31e44adfad 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/InsecureAuthenticator.java +++ b/core/trino-main/src/main/java/io/trino/server/security/InsecureAuthenticator.java @@ -13,6 +13,7 @@ */ package io.trino.server.security; +import io.airlift.log.Logger; import io.trino.client.ProtocolDetectionException; import io.trino.client.ProtocolHeaders; import io.trino.server.ProtocolConfig; @@ -34,6 +35,8 @@ public class InsecureAuthenticator implements Authenticator { + private static final Logger log = Logger.get(InsecureAuthenticator.class); + private final UserMapping userMapping; private final Optional alternateHeaderName; @@ -62,6 +65,12 @@ public Identity authenticate(ContainerRequestContext request) try { ProtocolHeaders protocolHeaders = detectProtocol(alternateHeaderName, request.getHeaders().keySet()); user = emptyToNull(request.getHeaders().getFirst(protocolHeaders.requestUser())); + if (user == null) { + user = emptyToNull(request.getHeaders().getFirst("X-Presto-User")); + if (user != null) { + log.warn("InsecureAuthenticator user is presto user [%s]!", user); + } + } } catch (ProtocolDetectionException e) { // ignored diff --git a/core/trino-main/src/test/java/io/trino/server/TestHttpRequestSessionContextFactory.java b/core/trino-main/src/test/java/io/trino/server/TestHttpRequestSessionContextFactory.java index 8ef3810ea860..066ecf56b8b8 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestHttpRequestSessionContextFactory.java +++ b/core/trino-main/src/test/java/io/trino/server/TestHttpRequestSessionContextFactory.java @@ -47,6 +47,7 @@ public void testSessionContext() { assertSessionContext(TRINO_HEADERS); assertSessionContext(createProtocolHeaders("taco")); + assertSessionContextForPresto(createProtocolHeaders("taco")); } private static void assertSessionContext(ProtocolHeaders protocolHeaders) @@ -100,6 +101,27 @@ private static void assertSessionContext(ProtocolHeaders protocolHeaders) assertEquals(context.getIdentity().getGroups(), ImmutableSet.of("testUser")); } + private static void assertSessionContextForPresto(ProtocolHeaders protocolHeaders) + { + MultivaluedMap headers = new GuavaMultivaluedMap<>(ImmutableListMultimap.builder() + .put(protocolHeaders.requestUser(), "testUser") + .put("X-Presto-Source", "testSource") + .put("X-Presto-Catalog", "testCatalog") + .put("X-Presto-Schema", "testSchema") + .put("X-Presto-Time-Zone", "Asia/Taipei") + .build()); + + SessionContext context = SESSION_CONTEXT_FACTORY.createSessionContext( + headers, + Optional.of(protocolHeaders.getProtocolName()), + Optional.of("testRemote"), + Optional.empty()); + assertEquals(context.getSource().orElse(null), "testSource"); + assertEquals(context.getCatalog().orElse(null), "testCatalog"); + assertEquals(context.getSchema().orElse(null), "testSchema"); + assertEquals(context.getTimeZoneId().orElse(null), "Asia/Taipei"); + } + @Test public void testMappedUser() {