Skip to content

Commit

Permalink
Inject OAuth2GatewayCookie
Browse files Browse the repository at this point in the history
  • Loading branch information
willmostly committed Oct 30, 2024
1 parent cbfa643 commit afbb9a1
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.trino.gateway.ha.resource.LoginResource;
import io.trino.gateway.ha.resource.PublicResource;
import io.trino.gateway.ha.resource.TrinoResource;
import io.trino.gateway.ha.router.OAuth2GatewayCookieProvider;
import io.trino.gateway.ha.security.AuthorizedExceptionMapper;
import io.trino.gateway.proxyserver.ForProxy;
import io.trino.gateway.proxyserver.ProxyRequestHandler;
Expand Down Expand Up @@ -121,6 +122,7 @@ public void configure(Binder binder)
registerResources(binder);
registerProxyResources(binder);
jaxrsBinder(binder).bind(RoutingTargetHandler.class);
jaxrsBinder(binder).bind(OAuth2GatewayCookieProvider.class);
addManagedApps(configuration, binder);
jaxrsBinder(binder).bind(AuthorizedExceptionMapper.class);
binder.bind(ProxyHandlerStats.class).in(Scopes.SINGLETON);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@

public class OAuth2GatewayCookieConfiguration
{
// Configuration initialization using dropwizard requires
// instance method setters. Values are global, and can be accessed using static getters
private List<String> routingPaths = ImmutableList.of("/oauth2");
private List<String> deletePaths = ImmutableList.of("/logout", "/oauth2/logout");
private Duration lifetime = Duration.valueOf("10m");

Expand All @@ -36,16 +33,6 @@ public void setDeletePaths(List<String> deletePaths)
this.deletePaths = deletePaths;
}

public List<String> getRoutingPaths()
{
return routingPaths;
}

public void setRoutingPaths(List<String> routingPaths)
{
this.routingPaths = routingPaths;
}

public Duration getLifetime()
{
return lifetime;
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import io.trino.gateway.ha.config.AuthorizationConfiguration;
import io.trino.gateway.ha.config.GatewayCookieConfigurationPropertiesProvider;
import io.trino.gateway.ha.config.HaGatewayConfiguration;
import io.trino.gateway.ha.config.OAuth2GatewayCookieConfigurationPropertiesProvider;
import io.trino.gateway.ha.config.RoutingRulesConfiguration;
import io.trino.gateway.ha.config.RulesExternalConfiguration;
import io.trino.gateway.ha.config.UserConfiguration;
Expand Down Expand Up @@ -79,9 +78,6 @@ public HaGatewayProviderModule(HaGatewayConfiguration configuration)

GatewayCookieConfigurationPropertiesProvider gatewayCookieConfigurationPropertiesProvider = GatewayCookieConfigurationPropertiesProvider.getInstance();
gatewayCookieConfigurationPropertiesProvider.initialize(configuration.getGatewayCookieConfiguration());

OAuth2GatewayCookieConfigurationPropertiesProvider oAuth2GatewayCookieConfigurationPropertiesProvider = OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance();
oAuth2GatewayCookieConfigurationPropertiesProvider.initialize(configuration.getOauth2GatewayCookieConfiguration());
}

private LbOAuthManager getOAuthManager(HaGatewayConfiguration configuration)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
*/
package io.trino.gateway.ha.router;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Streams;
import io.trino.gateway.ha.config.OAuth2GatewayCookieConfigurationPropertiesProvider;
import io.airlift.units.Duration;

import java.util.List;
import java.util.stream.Stream;

public class OAuth2GatewayCookie
Expand All @@ -25,15 +25,8 @@ public class OAuth2GatewayCookie
public static final String NAME = GatewayCookie.PREFIX + "OAUTH2";
public static final String OAUTH2_PATH = "/oauth2";

public OAuth2GatewayCookie(String backend)
public OAuth2GatewayCookie(String backend, List<String> deletePaths, Duration ttl)
{
super(
NAME,
null,
backend,
ImmutableList.copyOf(Streams.concat(Stream.of(OAUTH2_PATH), OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance().getDeletePaths().stream()).toList()),
OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance().getDeletePaths(),
OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance().getLifetime(),
0);
super(NAME, null, backend, Streams.concat(Stream.of(OAUTH2_PATH), deletePaths.stream()).toList(), deletePaths, ttl, 0);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.gateway.ha.router;

import com.google.inject.AbstractModule;
import com.google.inject.Inject;
import com.google.inject.Provides;
import io.airlift.units.Duration;
import io.trino.gateway.ha.config.HaGatewayConfiguration;

import java.util.List;

public class OAuth2GatewayCookieProvider
extends AbstractModule
{
private final List<String> deletePaths;
private final Duration ttl;

@Inject
public OAuth2GatewayCookieProvider(HaGatewayConfiguration configuration)
{
this.deletePaths = configuration.getOauth2GatewayCookieConfiguration().getDeletePaths();
this.ttl = configuration.getOauth2GatewayCookieConfiguration().getLifetime();
}

@Provides
public OAuth2GatewayCookie getOAuth2GatewayCookie(String backend)
{
return new OAuth2GatewayCookie(backend, deletePaths, ttl);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.trino.gateway.ha.config.HaGatewayConfiguration;
import io.trino.gateway.ha.router.GatewayCookie;
import io.trino.gateway.ha.router.OAuth2GatewayCookie;
import io.trino.gateway.ha.router.OAuth2GatewayCookieProvider;
import io.trino.gateway.ha.router.QueryHistoryManager;
import io.trino.gateway.ha.router.RoutingManager;
import io.trino.gateway.ha.router.TrinoRequestUser;
Expand Down Expand Up @@ -86,13 +87,15 @@ public class ProxyRequestHandler
private final List<String> statementPaths;
private final boolean includeClusterInfoInResponse;
private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider;
private final OAuth2GatewayCookieProvider oAuth2GatewayCookieProvider;

@Inject
public ProxyRequestHandler(
@ForProxy HttpClient httpClient,
RoutingManager routingManager,
QueryHistoryManager queryHistoryManager,
HaGatewayConfiguration haGatewayConfiguration)
HaGatewayConfiguration haGatewayConfiguration,
OAuth2GatewayCookieProvider oAuth2GatewayCookieProvider)
{
this.httpClient = requireNonNull(httpClient, "httpClient is null");
this.routingManager = requireNonNull(routingManager, "routingManager is null");
Expand All @@ -103,6 +106,7 @@ public ProxyRequestHandler(
addXForwardedHeaders = haGatewayConfiguration.getRouting().isAddXForwardedHeaders();
statementPaths = haGatewayConfiguration.getStatementPaths();
this.includeClusterInfoInResponse = haGatewayConfiguration.isIncludeClusterHostInResponse();
this.oAuth2GatewayCookieProvider = requireNonNull(oAuth2GatewayCookieProvider, "oAuth2GatewayCookieProvider is null");
}

@PreDestroy
Expand Down Expand Up @@ -193,7 +197,7 @@ private ImmutableList<NewCookie> getOAuth2GatewayCookie(URI remoteUri, HttpServl
if (remoteUri.getPath().startsWith(OAuth2GatewayCookie.OAUTH2_PATH)
&& !(servletRequest.getCookies() != null
&& Arrays.stream(servletRequest.getCookies()).anyMatch(c -> c.getName().equals(OAuth2GatewayCookie.NAME)))) {
GatewayCookie oauth2Cookie = new OAuth2GatewayCookie(getRemoteTarget(remoteUri));
GatewayCookie oauth2Cookie = oAuth2GatewayCookieProvider.getOAuth2GatewayCookie(getRemoteTarget(remoteUri));
return ImmutableList.of(oauth2Cookie.toNewCookie());
}
else if (servletRequest.getCookies() != null) {
Expand Down

0 comments on commit afbb9a1

Please sign in to comment.