Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code Refactoring #533

Merged
merged 4 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions gateway-ha/src/main/java/io/trino/gateway/baseapp/BaseApp.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package io.trino.gateway.baseapp;

import com.google.common.collect.MoreCollectors;
import com.google.inject.Binder;
import com.google.inject.Module;
import com.google.inject.Scopes;
Expand Down Expand Up @@ -44,6 +43,7 @@
import java.util.List;
import java.util.Optional;

import static com.google.common.collect.MoreCollectors.toOptional;
import static io.airlift.http.client.HttpClientBinder.httpClientBinder;
import static io.airlift.jaxrs.JaxrsBinder.jaxrsBinder;
import static java.lang.String.format;
Expand All @@ -54,11 +54,11 @@ public class BaseApp
implements Module
{
private static final Logger logger = Logger.get(BaseApp.class);
private final HaGatewayConfiguration haGatewayConfiguration;
private final HaGatewayConfiguration configuration;

public BaseApp(HaGatewayConfiguration haGatewayConfiguration)
public BaseApp(HaGatewayConfiguration configuration)
{
this.haGatewayConfiguration = requireNonNull(haGatewayConfiguration);
this.configuration = requireNonNull(configuration, "configuration is null");
}

private static Module newModule(String clazz, HaGatewayConfiguration configuration)
Expand Down Expand Up @@ -89,7 +89,7 @@ private static void validateModules(List<Module> modules, HaGatewayConfiguration
{
Optional<Module> routerProvider = modules.stream()
.filter(module -> module instanceof RouterBaseModule)
.collect(MoreCollectors.toOptional());
.collect(toOptional());
if (routerProvider.isEmpty()) {
logger.warn("Router provider doesn't exist in the config, using the StochasticRoutingManagerProvider");
String clazz = StochasticRoutingManagerProvider.class.getCanonicalName();
Expand All @@ -116,12 +116,12 @@ public static List<Module> addModules(HaGatewayConfiguration configuration)
@Override
public void configure(Binder binder)
{
binder.bind(HaGatewayConfiguration.class).toInstance(haGatewayConfiguration);
binder.bind(HaGatewayConfiguration.class).toInstance(configuration);
registerAuthFilters(binder);
registerResources(binder);
registerProxyResources(binder);
jaxrsBinder(binder).bind(RoutingTargetHandler.class);
addManagedApps(this.haGatewayConfiguration, binder);
addManagedApps(configuration, binder);
jaxrsBinder(binder).bind(AuthorizedExceptionMapper.class);
binder.bind(ProxyHandlerStats.class).in(Scopes.SINGLETON);
newExporter(binder).export(ProxyHandlerStats.class).withGeneratedName();
Expand All @@ -136,7 +136,7 @@ private static void addManagedApps(HaGatewayConfiguration configuration, Binder
configuration.getManagedApps().forEach(
clazz -> {
try {
Class c = Class.forName(clazz);
Class<?> c = Class.forName(clazz);
binder.bind(c).in(Scopes.SINGLETON);
}
catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

import static java.util.Objects.requireNonNull;

public class ActiveClusterMonitor
{
public static final int MONITOR_TASK_DELAY_SECONDS = 60;
Expand All @@ -39,9 +42,8 @@ public class ActiveClusterMonitor

private final int taskDelaySeconds;
private final ClusterStatsMonitor clusterStatsMonitor;
private volatile boolean monitorActive = true;
private final ExecutorService executorService = Executors.newFixedThreadPool(DEFAULT_THREAD_POOL_SIZE);
private final ExecutorService singleTaskExecutor = Executors.newSingleThreadExecutor();
private final ScheduledExecutorService scheduledExecutor = Executors.newSingleThreadScheduledExecutor();

@Inject
public ActiveClusterMonitor(
Expand All @@ -50,65 +52,54 @@ public ActiveClusterMonitor(
MonitorConfiguration monitorConfiguration,
ClusterStatsMonitor clusterStatsMonitor)
{
this.clusterStatsObservers = clusterStatsObservers;
this.gatewayBackendManager = gatewayBackendManager;
this.clusterStatsMonitor = requireNonNull(clusterStatsMonitor, "clusterStatsMonitor is null");
this.clusterStatsObservers = requireNonNull(clusterStatsObservers, "clusterStatsObservers is null");
this.gatewayBackendManager = requireNonNull(gatewayBackendManager, "gatewayBackendManager is null");
this.taskDelaySeconds = monitorConfiguration.getTaskDelaySeconds();
this.clusterStatsMonitor = clusterStatsMonitor;
log.info("Running cluster monitor with connection task delay of %d seconds", taskDelaySeconds);
}

/**
* Run an app that queries all active trino clusters for stats.
*/
@PostConstruct
public void start()
{
singleTaskExecutor.submit(
() -> {
while (monitorActive) {
try {
log.info("Getting the stats for the active clusters");
List<ProxyBackendConfiguration> activeClusters =
gatewayBackendManager.getAllActiveBackends();
List<Future<ClusterStats>> futures = new ArrayList<>();
for (ProxyBackendConfiguration backend : activeClusters) {
Future<ClusterStats> call =
executorService.submit(() -> clusterStatsMonitor.monitor(backend));
futures.add(call);
}
List<ClusterStats> stats = new ArrayList<>();
for (Future<ClusterStats> clusterStatsFuture : futures) {
ClusterStats clusterStats = clusterStatsFuture.get();
stats.add(clusterStats);
}
log.info("Running cluster monitor with connection task delay of %d seconds", taskDelaySeconds);
scheduledExecutor.scheduleAtFixedRate(() -> {
try {
log.info("Getting stats for all active clusters");
List<ProxyBackendConfiguration> activeClusters =
gatewayBackendManager.getAllActiveBackends();
List<Future<ClusterStats>> futures = new ArrayList<>();
for (ProxyBackendConfiguration backend : activeClusters) {
Future<ClusterStats> call = executorService.submit(() -> clusterStatsMonitor.monitor(backend));
futures.add(call);
}
List<ClusterStats> stats = new ArrayList<>();
for (Future<ClusterStats> clusterStatsFuture : futures) {
ClusterStats clusterStats = clusterStatsFuture.get();
stats.add(clusterStats);
}

if (clusterStatsObservers != null) {
for (TrinoClusterStatsObserver observer : clusterStatsObservers) {
observer.observe(stats);
}
}
}
catch (Exception e) {
log.error(e, "Error performing backend monitor tasks");
}
try {
Thread.sleep(TimeUnit.SECONDS.toMillis(taskDelaySeconds));
}
catch (Exception e) {
log.error(e, "Error with monitor task");
}
if (clusterStatsObservers != null) {
for (TrinoClusterStatsObserver observer : clusterStatsObservers) {
observer.observe(stats);
}
});
}
}
catch (Exception e) {
log.error(e, "Error performing backend monitor tasks");
}
try {
Thread.sleep(TimeUnit.SECONDS.toMillis(taskDelaySeconds));
}
catch (Exception e) {
log.error(e, "Error with monitor task");
}
}, 0, taskDelaySeconds, TimeUnit.SECONDS);
}

/**
* Shut down the app.
*/
@PreDestroy
public void stop()
{
this.monitorActive = false;
this.executorService.shutdown();
this.singleTaskExecutor.shutdown();
executorService.shutdownNow();
scheduledExecutor.shutdownNow();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Strings;
import io.airlift.http.client.HttpStatus;
import io.airlift.log.Logger;
import io.trino.gateway.ha.config.BackendStateConfiguration;
Expand All @@ -33,15 +32,18 @@
import java.util.List;
import java.util.Map;

import static com.google.common.base.Strings.isNullOrEmpty;
import static io.airlift.http.client.HttpStatus.fromStatusCode;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.UI_API_QUEUED_LIST_PATH;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.UI_API_STATS_PATH;
import static io.trino.gateway.ha.handler.QueryIdCachingProxyHandler.UI_LOGIN_PATH;
import static java.util.Objects.requireNonNull;

public class ClusterStatsHttpMonitor
implements ClusterStatsMonitor
{
private static final Logger log = Logger.get(ClusterStatsHttpMonitor.class);
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private static final String SESSION_USER = "sessionUser";

private final String username;
Expand All @@ -59,14 +61,13 @@ public ClusterStats monitor(ProxyBackendConfiguration backend)
ClusterStats.Builder clusterStats = ClusterStatsMonitor.getClusterStatsBuilder(backend);
// Fetch Cluster level Stats.
String response = queryCluster(backend, UI_API_STATS_PATH);
if (Strings.isNullOrEmpty(response)) {
if (isNullOrEmpty(response)) {
log.error("Received null/empty response for %s", UI_API_STATS_PATH);
return clusterStats.build();
}

try {
HashMap<String, Object> result = new ObjectMapper().readValue(response, HashMap.class);

HashMap<String, Object> result = OBJECT_MAPPER.readValue(response, new TypeReference<>() {});
int activeWorkers = (int) result.get("activeWorkers");
clusterStats
.numWorkerNodes(activeWorkers)
Expand All @@ -84,18 +85,14 @@ public ClusterStats monitor(ProxyBackendConfiguration backend)
// Fetch User Level Stats.
Map<String, Integer> clusterUserStats = new HashMap<>();
response = queryCluster(backend, UI_API_QUEUED_LIST_PATH);
if (Strings.isNullOrEmpty(response)) {
if (isNullOrEmpty(response)) {
log.error("Received null/empty response for %s", UI_API_QUEUED_LIST_PATH);
return clusterStats.build();
}
try {
List<Map<String, Object>> queries = new ObjectMapper().readValue(response,
new TypeReference<List<Map<String, Object>>>()
{
});

for (Map<String, Object> q : queries) {
String user = (String) q.get(SESSION_USER);
List<Map<String, Object>> queries = OBJECT_MAPPER.readValue(response, new TypeReference<>() {});
for (Map<String, Object> query : queries) {
String user = (String) query.get(SESSION_USER);
clusterUserStats.put(user, clusterUserStats.getOrDefault(user, 0) + 1);
}
}
Expand Down Expand Up @@ -148,19 +145,15 @@ private String queryCluster(ProxyBackendConfiguration backend, String path)
Call call = client.newCall(request);

try (Response res = call.execute()) {
switch (fromStatusCode(res.code())) {
case HttpStatus.OK:
return res.body().string();
case HttpStatus.UNAUTHORIZED:
return switch (fromStatusCode(res.code())) {
case HttpStatus.OK -> requireNonNull(res.body(), "body is null").string();
case HttpStatus.UNAUTHORIZED -> {
log.info("Unauthorized to fetch cluster stats");
log.debug("username: %s, targetUrl: %s, cookieStore: %s",
username,
targetUrl,
client.cookieJar().loadForRequest(HttpUrl.parse(targetUrl)));
return null;
default:
return null;
}
log.debug("username: %s, targetUrl: %s, cookieStore: %s", username, targetUrl, client.cookieJar().loadForRequest(HttpUrl.parse(targetUrl)));
yield null;
}
default -> null;
};
}
catch (IOException e) {
log.warn(e, "Failed to fetch cluster stats");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ private TrinoStatus checkStatus(String baseUrl, int retriesRemaining)
.build();
try {
ServerInfo serverInfo = client.execute(request, SERVER_INFO_JSON_RESPONSE_HANDLER);
return serverInfo.isStarting() ? TrinoStatus.PENDING : TrinoStatus.HEALTHY;
return serverInfo.starting() ? TrinoStatus.PENDING : TrinoStatus.HEALTHY;
}
catch (UnexpectedResponseException e) {
if (shouldRetry(e.getStatusCode())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ public ClusterStats monitor(ProxyBackendConfiguration backend)
return clusterStats.build(); // TODO Invalid configuration should fail
}

try (Connection conn = DriverManager.getConnection(jdbcUrl, properties)) {
PreparedStatement stmt = SimpleTimeLimiter.create(Executors.newSingleThreadExecutor()).callWithTimeout(
() -> conn.prepareStatement(STATE_QUERY), 10, TimeUnit.SECONDS);
stmt.setString(1, (String) properties.get("user"));
try (Connection conn = DriverManager.getConnection(jdbcUrl, properties);
PreparedStatement statement = SimpleTimeLimiter.create(Executors.newSingleThreadExecutor()).callWithTimeout(
() -> conn.prepareStatement(STATE_QUERY), 10, TimeUnit.SECONDS)) {
statement.setString(1, (String) properties.get("user"));
Map<String, Integer> partialState = new HashMap<>();
ResultSet rs = stmt.executeQuery();
ResultSet rs = statement.executeQuery();
while (rs.next()) {
partialState.put(rs.getString("state"), rs.getInt("count"));
}
Expand All @@ -91,10 +91,10 @@ public ClusterStats monitor(ProxyBackendConfiguration backend)
.build();
}
catch (TimeoutException e) {
log.error(e, "timed out fetching status for %s backend", url);
log.error(e, "Timed out fetching status for %s backend", url);
}
catch (Exception e) {
log.error(e, "could not fetch status for %s backend", url);
log.error(e, "Could not fetch status for %s backend", url);
}
return clusterStats.build();
}
Expand Down

This file was deleted.

Loading