Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inject OAuth2GatewayCookie #539

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.trino.gateway.ha.resource.LoginResource;
import io.trino.gateway.ha.resource.PublicResource;
import io.trino.gateway.ha.resource.TrinoResource;
import io.trino.gateway.ha.router.OAuth2GatewayCookieProvider;
import io.trino.gateway.ha.security.AuthorizedExceptionMapper;
import io.trino.gateway.proxyserver.ForProxy;
import io.trino.gateway.proxyserver.ProxyRequestHandler;
Expand Down Expand Up @@ -121,6 +122,7 @@ public void configure(Binder binder)
registerResources(binder);
registerProxyResources(binder);
jaxrsBinder(binder).bind(RoutingTargetHandler.class);
binder.bind(OAuth2GatewayCookieProvider.class);
addManagedApps(configuration, binder);
jaxrsBinder(binder).bind(AuthorizedExceptionMapper.class);
binder.bind(ProxyHandlerStats.class).in(Scopes.SINGLETON);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@

public class OAuth2GatewayCookieConfiguration
{
// Configuration initialization using dropwizard requires
// instance method setters. Values are global, and can be accessed using static getters
private List<String> routingPaths = ImmutableList.of("/oauth2");
private List<String> deletePaths = ImmutableList.of("/logout", "/oauth2/logout");
private Duration lifetime = Duration.valueOf("10m");

Expand All @@ -36,16 +33,6 @@ public void setDeletePaths(List<String> deletePaths)
this.deletePaths = deletePaths;
}

public List<String> getRoutingPaths()
{
return routingPaths;
}

public void setRoutingPaths(List<String> routingPaths)
{
this.routingPaths = routingPaths;
}

public Duration getLifetime()
{
return lifetime;
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import io.trino.gateway.ha.config.AuthorizationConfiguration;
import io.trino.gateway.ha.config.GatewayCookieConfigurationPropertiesProvider;
import io.trino.gateway.ha.config.HaGatewayConfiguration;
import io.trino.gateway.ha.config.OAuth2GatewayCookieConfigurationPropertiesProvider;
import io.trino.gateway.ha.config.RoutingRulesConfiguration;
import io.trino.gateway.ha.config.RulesExternalConfiguration;
import io.trino.gateway.ha.config.UserConfiguration;
Expand Down Expand Up @@ -79,9 +78,6 @@ public HaGatewayProviderModule(HaGatewayConfiguration configuration)

GatewayCookieConfigurationPropertiesProvider gatewayCookieConfigurationPropertiesProvider = GatewayCookieConfigurationPropertiesProvider.getInstance();
gatewayCookieConfigurationPropertiesProvider.initialize(configuration.getGatewayCookieConfiguration());

OAuth2GatewayCookieConfigurationPropertiesProvider oAuth2GatewayCookieConfigurationPropertiesProvider = OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance();
oAuth2GatewayCookieConfigurationPropertiesProvider.initialize(configuration.getOauth2GatewayCookieConfiguration());
}

private LbOAuthManager getOAuthManager(HaGatewayConfiguration configuration)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,18 @@
package io.trino.gateway.ha.router;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Streams;
import io.trino.gateway.ha.config.OAuth2GatewayCookieConfigurationPropertiesProvider;
import io.airlift.units.Duration;

import java.util.stream.Stream;
import java.util.List;

public class OAuth2GatewayCookie
extends GatewayCookie
{
public static final String NAME = GatewayCookie.PREFIX + "OAUTH2";
public static final String OAUTH2_PATH = "/oauth2";

public OAuth2GatewayCookie(String backend)
public OAuth2GatewayCookie(String backend, List<String> deletePaths, Duration ttl)
{
super(
NAME,
null,
backend,
ImmutableList.copyOf(Streams.concat(Stream.of(OAUTH2_PATH), OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance().getDeletePaths().stream()).toList()),
OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance().getDeletePaths(),
OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance().getLifetime(),
0);
super(NAME, null, backend, new ImmutableList.Builder<String>().add(OAUTH2_PATH).addAll(deletePaths).build(), deletePaths, ttl, 0);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.gateway.ha.router;

import com.google.common.collect.ImmutableList;
import com.google.inject.AbstractModule;
import com.google.inject.Inject;
import io.airlift.units.Duration;
import io.trino.gateway.ha.config.HaGatewayConfiguration;

import java.util.List;

import static java.util.Objects.requireNonNull;

public class OAuth2GatewayCookieProvider
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why we need a new class as the logic is pretty simple. Why not putting deletePaths & ttl on fields in ProxyRequestHandler?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could do this. It seemed cleaner to keep the oauth2 cookie configuration out of ProxyRequestHandler since it does not use it directly. How do you typically draw the line between injecting vs plumbing configuration properties through by hand?

extends AbstractModule
{
private final List<String> deletePaths;
private final Duration ttl;

@Inject
public OAuth2GatewayCookieProvider(HaGatewayConfiguration configuration)
{
requireNonNull(configuration.getOauth2GatewayCookieConfiguration(), "OAuth2GatewatCookieConfiguration is null");
this.deletePaths = ImmutableList.copyOf(configuration.getOauth2GatewayCookieConfiguration().getDeletePaths());
this.ttl = configuration.getOauth2GatewayCookieConfiguration().getLifetime();
}

public OAuth2GatewayCookie getOAuth2GatewayCookie(String backend)
{
return new OAuth2GatewayCookie(backend, deletePaths, ttl);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.trino.gateway.ha.config.HaGatewayConfiguration;
import io.trino.gateway.ha.router.GatewayCookie;
import io.trino.gateway.ha.router.OAuth2GatewayCookie;
import io.trino.gateway.ha.router.OAuth2GatewayCookieProvider;
import io.trino.gateway.ha.router.QueryHistoryManager;
import io.trino.gateway.ha.router.RoutingManager;
import io.trino.gateway.ha.router.TrinoRequestUser;
Expand Down Expand Up @@ -86,13 +87,15 @@ public class ProxyRequestHandler
private final List<String> statementPaths;
private final boolean includeClusterInfoInResponse;
private final TrinoRequestUser.TrinoRequestUserProvider trinoRequestUserProvider;
private final OAuth2GatewayCookieProvider oAuth2GatewayCookieProvider;

@Inject
public ProxyRequestHandler(
@ForProxy HttpClient httpClient,
RoutingManager routingManager,
QueryHistoryManager queryHistoryManager,
HaGatewayConfiguration haGatewayConfiguration)
HaGatewayConfiguration haGatewayConfiguration,
OAuth2GatewayCookieProvider oAuth2GatewayCookieProvider)
{
this.httpClient = requireNonNull(httpClient, "httpClient is null");
this.routingManager = requireNonNull(routingManager, "routingManager is null");
Expand All @@ -103,6 +106,7 @@ public ProxyRequestHandler(
addXForwardedHeaders = haGatewayConfiguration.getRouting().isAddXForwardedHeaders();
statementPaths = haGatewayConfiguration.getStatementPaths();
this.includeClusterInfoInResponse = haGatewayConfiguration.isIncludeClusterHostInResponse();
this.oAuth2GatewayCookieProvider = requireNonNull(oAuth2GatewayCookieProvider, "oAuth2GatewayCookieProvider is null");
}

@PreDestroy
Expand Down Expand Up @@ -193,7 +197,7 @@ private ImmutableList<NewCookie> getOAuth2GatewayCookie(URI remoteUri, HttpServl
if (remoteUri.getPath().startsWith(OAuth2GatewayCookie.OAUTH2_PATH)
&& !(servletRequest.getCookies() != null
&& Arrays.stream(servletRequest.getCookies()).anyMatch(c -> c.getName().equals(OAuth2GatewayCookie.NAME)))) {
GatewayCookie oauth2Cookie = new OAuth2GatewayCookie(getRemoteTarget(remoteUri));
GatewayCookie oauth2Cookie = oAuth2GatewayCookieProvider.getOAuth2GatewayCookie(getRemoteTarget(remoteUri));
return ImmutableList.of(oauth2Cookie.toNewCookie());
}
else if (servletRequest.getCookies() != null) {
Expand Down