From afbb9a1a9f9e759422266c1f11af548224fab474 Mon Sep 17 00:00:00 2001 From: Will Morrison Date: Wed, 30 Oct 2024 08:44:38 -0400 Subject: [PATCH] Inject OAuth2GatewayCookie --- .../io/trino/gateway/baseapp/BaseApp.java | 2 + .../OAuth2GatewayCookieConfiguration.java | 13 ---- ...CookieConfigurationPropertiesProvider.java | 63 ------------------- .../ha/module/HaGatewayProviderModule.java | 4 -- .../ha/router/OAuth2GatewayCookie.java | 15 ++--- .../router/OAuth2GatewayCookieProvider.java | 42 +++++++++++++ .../proxyserver/ProxyRequestHandler.java | 8 ++- 7 files changed, 54 insertions(+), 93 deletions(-) delete mode 100644 gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfigurationPropertiesProvider.java create mode 100644 gateway-ha/src/main/java/io/trino/gateway/ha/router/OAuth2GatewayCookieProvider.java diff --git a/gateway-ha/src/main/java/io/trino/gateway/baseapp/BaseApp.java b/gateway-ha/src/main/java/io/trino/gateway/baseapp/BaseApp.java index aebdf4a67..b359ba461 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/baseapp/BaseApp.java +++ b/gateway-ha/src/main/java/io/trino/gateway/baseapp/BaseApp.java @@ -31,6 +31,7 @@ import io.trino.gateway.ha.resource.LoginResource; import io.trino.gateway.ha.resource.PublicResource; import io.trino.gateway.ha.resource.TrinoResource; +import io.trino.gateway.ha.router.OAuth2GatewayCookieProvider; import io.trino.gateway.ha.security.AuthorizedExceptionMapper; import io.trino.gateway.proxyserver.ForProxy; import io.trino.gateway.proxyserver.ProxyRequestHandler; @@ -121,6 +122,7 @@ public void configure(Binder binder) registerResources(binder); registerProxyResources(binder); jaxrsBinder(binder).bind(RoutingTargetHandler.class); + jaxrsBinder(binder).bind(OAuth2GatewayCookieProvider.class); addManagedApps(configuration, binder); jaxrsBinder(binder).bind(AuthorizedExceptionMapper.class); binder.bind(ProxyHandlerStats.class).in(Scopes.SINGLETON); diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfiguration.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfiguration.java index d5224a16e..9ed2f29c1 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfiguration.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfiguration.java @@ -20,9 +20,6 @@ public class OAuth2GatewayCookieConfiguration { - // Configuration initialization using dropwizard requires - // instance method setters. Values are global, and can be accessed using static getters - private List routingPaths = ImmutableList.of("/oauth2"); private List deletePaths = ImmutableList.of("/logout", "/oauth2/logout"); private Duration lifetime = Duration.valueOf("10m"); @@ -36,16 +33,6 @@ public void setDeletePaths(List deletePaths) this.deletePaths = deletePaths; } - public List getRoutingPaths() - { - return routingPaths; - } - - public void setRoutingPaths(List routingPaths) - { - this.routingPaths = routingPaths; - } - public Duration getLifetime() { return lifetime; diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfigurationPropertiesProvider.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfigurationPropertiesProvider.java deleted file mode 100644 index dfbd14560..000000000 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfigurationPropertiesProvider.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.gateway.ha.config; - -import io.airlift.units.Duration; - -import java.util.List; - -public class OAuth2GatewayCookieConfigurationPropertiesProvider -{ - private static final OAuth2GatewayCookieConfigurationPropertiesProvider instance = new OAuth2GatewayCookieConfigurationPropertiesProvider(); - - private OAuth2GatewayCookieConfiguration oAuth2GatewayCookieConfiguration; - - private OAuth2GatewayCookieConfigurationPropertiesProvider() - {} - - public static OAuth2GatewayCookieConfigurationPropertiesProvider getInstance() - { - return instance; - } - - public void initialize(OAuth2GatewayCookieConfiguration oAuth2GatewayCookieConfiguration) - { - this.oAuth2GatewayCookieConfiguration = oAuth2GatewayCookieConfiguration; - } - - public List getDeletePaths() - { - ensureInitialized(); - return oAuth2GatewayCookieConfiguration.getDeletePaths(); - } - - public List getRoutingPaths() - { - ensureInitialized(); - return oAuth2GatewayCookieConfiguration.getRoutingPaths(); - } - - public Duration getLifetime() - { - ensureInitialized(); - return oAuth2GatewayCookieConfiguration.getLifetime(); - } - - private void ensureInitialized() - { - if (oAuth2GatewayCookieConfiguration == null) { - throw new IllegalStateException("getInstance.initialize(OAuth2GatewayCookieConfiguration) must be called before use"); - } - } -} diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java b/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java index 550c98daf..2d370ba6f 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java @@ -21,7 +21,6 @@ import io.trino.gateway.ha.config.AuthorizationConfiguration; import io.trino.gateway.ha.config.GatewayCookieConfigurationPropertiesProvider; import io.trino.gateway.ha.config.HaGatewayConfiguration; -import io.trino.gateway.ha.config.OAuth2GatewayCookieConfigurationPropertiesProvider; import io.trino.gateway.ha.config.RoutingRulesConfiguration; import io.trino.gateway.ha.config.RulesExternalConfiguration; import io.trino.gateway.ha.config.UserConfiguration; @@ -79,9 +78,6 @@ public HaGatewayProviderModule(HaGatewayConfiguration configuration) GatewayCookieConfigurationPropertiesProvider gatewayCookieConfigurationPropertiesProvider = GatewayCookieConfigurationPropertiesProvider.getInstance(); gatewayCookieConfigurationPropertiesProvider.initialize(configuration.getGatewayCookieConfiguration()); - - OAuth2GatewayCookieConfigurationPropertiesProvider oAuth2GatewayCookieConfigurationPropertiesProvider = OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance(); - oAuth2GatewayCookieConfigurationPropertiesProvider.initialize(configuration.getOauth2GatewayCookieConfiguration()); } private LbOAuthManager getOAuthManager(HaGatewayConfiguration configuration) diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/OAuth2GatewayCookie.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/OAuth2GatewayCookie.java index 72450ebcc..6fa4c9b71 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/router/OAuth2GatewayCookie.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/OAuth2GatewayCookie.java @@ -13,10 +13,10 @@ */ package io.trino.gateway.ha.router; -import com.google.common.collect.ImmutableList; import com.google.common.collect.Streams; -import io.trino.gateway.ha.config.OAuth2GatewayCookieConfigurationPropertiesProvider; +import io.airlift.units.Duration; +import java.util.List; import java.util.stream.Stream; public class OAuth2GatewayCookie @@ -25,15 +25,8 @@ public class OAuth2GatewayCookie public static final String NAME = GatewayCookie.PREFIX + "OAUTH2"; public static final String OAUTH2_PATH = "/oauth2"; - public OAuth2GatewayCookie(String backend) + public OAuth2GatewayCookie(String backend, List deletePaths, Duration ttl) { - super( - NAME, - null, - backend, - ImmutableList.copyOf(Streams.concat(Stream.of(OAUTH2_PATH), OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance().getDeletePaths().stream()).toList()), - OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance().getDeletePaths(), - OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance().getLifetime(), - 0); + super(NAME, null, backend, Streams.concat(Stream.of(OAUTH2_PATH), deletePaths.stream()).toList(), deletePaths, ttl, 0); } } diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/OAuth2GatewayCookieProvider.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/OAuth2GatewayCookieProvider.java new file mode 100644 index 000000000..1a6d7d686 --- /dev/null +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/OAuth2GatewayCookieProvider.java @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.router; + +import com.google.inject.AbstractModule; +import com.google.inject.Inject; +import com.google.inject.Provides; +import io.airlift.units.Duration; +import io.trino.gateway.ha.config.HaGatewayConfiguration; + +import java.util.List; + +public class OAuth2GatewayCookieProvider + extends AbstractModule +{ + private final List deletePaths; + private final Duration ttl; + + @Inject + public OAuth2GatewayCookieProvider(HaGatewayConfiguration configuration) + { + this.deletePaths = configuration.getOauth2GatewayCookieConfiguration().getDeletePaths(); + this.ttl = configuration.getOauth2GatewayCookieConfiguration().getLifetime(); + } + + @Provides + public OAuth2GatewayCookie getOAuth2GatewayCookie(String backend) + { + return new OAuth2GatewayCookie(backend, deletePaths, ttl); + } +} diff --git a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java index 859c79325..7f57cc201 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java +++ b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyRequestHandler.java @@ -27,6 +27,7 @@ import io.trino.gateway.ha.config.HaGatewayConfiguration; import io.trino.gateway.ha.router.GatewayCookie; import io.trino.gateway.ha.router.OAuth2GatewayCookie; +import io.trino.gateway.ha.router.OAuth2GatewayCookieProvider; import io.trino.gateway.ha.router.QueryHistoryManager; import io.trino.gateway.ha.router.RoutingManager; import io.trino.gateway.ha.router.TrinoRequestUser; @@ -86,13 +87,15 @@ public class ProxyRequestHandler private final List statementPaths; private final boolean includeClusterInfoInResponse; private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider; + private final OAuth2GatewayCookieProvider oAuth2GatewayCookieProvider; @Inject public ProxyRequestHandler( @ForProxy HttpClient httpClient, RoutingManager routingManager, QueryHistoryManager queryHistoryManager, - HaGatewayConfiguration haGatewayConfiguration) + HaGatewayConfiguration haGatewayConfiguration, + OAuth2GatewayCookieProvider oAuth2GatewayCookieProvider) { this.httpClient = requireNonNull(httpClient, "httpClient is null"); this.routingManager = requireNonNull(routingManager, "routingManager is null"); @@ -103,6 +106,7 @@ public ProxyRequestHandler( addXForwardedHeaders = haGatewayConfiguration.getRouting().isAddXForwardedHeaders(); statementPaths = haGatewayConfiguration.getStatementPaths(); this.includeClusterInfoInResponse = haGatewayConfiguration.isIncludeClusterHostInResponse(); + this.oAuth2GatewayCookieProvider = requireNonNull(oAuth2GatewayCookieProvider, "oAuth2GatewayCookieProvider is null"); } @PreDestroy @@ -193,7 +197,7 @@ private ImmutableList getOAuth2GatewayCookie(URI remoteUri, HttpServl if (remoteUri.getPath().startsWith(OAuth2GatewayCookie.OAUTH2_PATH) && !(servletRequest.getCookies() != null && Arrays.stream(servletRequest.getCookies()).anyMatch(c -> c.getName().equals(OAuth2GatewayCookie.NAME)))) { - GatewayCookie oauth2Cookie = new OAuth2GatewayCookie(getRemoteTarget(remoteUri)); + GatewayCookie oauth2Cookie = oAuth2GatewayCookieProvider.getOAuth2GatewayCookie(getRemoteTarget(remoteUri)); return ImmutableList.of(oauth2Cookie.toNewCookie()); } else if (servletRequest.getCookies() != null) {