Skip to content

Commit

Permalink
Remove mechanism for supporting multiple memory pools
Browse files Browse the repository at this point in the history
  • Loading branch information
losipiuk committed Feb 17, 2022
1 parent 1ee64c7 commit cd2cdda
Show file tree
Hide file tree
Showing 57 changed files with 206 additions and 811 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import java.util.concurrent.Executor;

import static com.google.common.util.concurrent.Futures.immediateVoidFuture;
import static io.trino.memory.LocalMemoryManager.GENERAL_POOL;
import static io.trino.server.DynamicFilterService.DynamicFiltersStats;
import static io.trino.util.Failures.toFailure;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -214,7 +213,6 @@ private static QueryInfo immediateFailureQueryInfo(
session.getQueryId(),
session.toSessionRepresentation(),
QueryState.FAILED,
GENERAL_POOL,
false,
self,
ImmutableList.of(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import io.trino.execution.QueryPreparer.PreparedQuery;
import io.trino.execution.StateMachine.StateChangeListener;
import io.trino.execution.warnings.WarningCollector;
import io.trino.memory.VersionedMemoryPoolId;
import io.trino.server.BasicQueryInfo;
import io.trino.server.protocol.Slug;
import io.trino.spi.QueryId;
Expand Down Expand Up @@ -77,18 +76,6 @@ public Slug getSlug()
return slug;
}

@Override
public VersionedMemoryPoolId getMemoryPool()
{
return stateMachine.getMemoryPool();
}

@Override
public void setMemoryPool(VersionedMemoryPoolId poolId)
{
stateMachine.setMemoryPool(poolId);
}

@Override
public Session getSession()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package io.trino.execution;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Ordering;
import io.airlift.log.Logger;
import io.trino.FeaturesConfig;
Expand All @@ -26,7 +25,6 @@
import io.trino.operator.OperatorContext;
import io.trino.operator.PipelineContext;
import io.trino.operator.TaskContext;
import io.trino.spi.memory.MemoryPoolId;

import javax.annotation.Nullable;
import javax.annotation.PostConstruct;
Expand All @@ -52,7 +50,7 @@ public class MemoryRevokingScheduler
private static final Logger log = Logger.get(MemoryRevokingScheduler.class);

private static final Ordering<SqlTask> ORDER_BY_CREATE_TIME = Ordering.natural().onResultOf(SqlTask::getTaskCreatedTime);
private final List<MemoryPool> memoryPools;
private final MemoryPool memoryPool;
private final Supplier<? extends Collection<SqlTask>> currentTasksSupplier;
private final ScheduledExecutorService taskManagementExecutor;
private final double memoryRevokingThreshold;
Expand All @@ -73,7 +71,7 @@ public MemoryRevokingScheduler(
FeaturesConfig config)
{
this(
ImmutableList.copyOf(getMemoryPools(localMemoryManager)),
localMemoryManager.getMemoryPool(),
requireNonNull(sqlTaskManager, "sqlTaskManager cannot be null")::getAllTasks,
requireNonNull(taskManagementExecutor, "taskManagementExecutor cannot be null").getExecutor(),
config.getMemoryRevokingThreshold(),
Expand All @@ -82,13 +80,13 @@ public MemoryRevokingScheduler(

@VisibleForTesting
MemoryRevokingScheduler(
List<MemoryPool> memoryPools,
MemoryPool memoryPool,
Supplier<? extends Collection<SqlTask>> currentTasksSupplier,
ScheduledExecutorService taskManagementExecutor,
double memoryRevokingThreshold,
double memoryRevokingTarget)
{
this.memoryPools = ImmutableList.copyOf(requireNonNull(memoryPools, "memoryPools is null"));
this.memoryPool = requireNonNull(memoryPool, "memoryPool is null");
this.currentTasksSupplier = requireNonNull(currentTasksSupplier, "currentTasksSupplier is null");
this.taskManagementExecutor = requireNonNull(taskManagementExecutor, "taskManagementExecutor is null");
this.memoryRevokingThreshold = checkFraction(memoryRevokingThreshold, "memoryRevokingThreshold");
Expand All @@ -106,14 +104,6 @@ private static double checkFraction(double value, String valueName)
return value;
}

private static List<MemoryPool> getMemoryPools(LocalMemoryManager localMemoryManager)
{
requireNonNull(localMemoryManager, "localMemoryManager cannot be null");
ImmutableList.Builder<MemoryPool> builder = new ImmutableList.Builder<>();
builder.add(localMemoryManager.getGeneralPool());
return builder.build();
}

@PostConstruct
public void start()
{
Expand Down Expand Up @@ -141,13 +131,13 @@ public void stop()
scheduledFuture = null;
}

memoryPools.forEach(memoryPool -> memoryPool.removeListener(memoryPoolListener));
memoryPool.removeListener(memoryPoolListener);
}

@VisibleForTesting
void registerPoolListeners()
{
memoryPools.forEach(memoryPool -> memoryPool.addListener(memoryPoolListener));
memoryPool.addListener(memoryPoolListener);
}

private void onMemoryReserved(MemoryPool memoryPool)
Expand Down Expand Up @@ -190,18 +180,10 @@ private void scheduleRevoking()
private synchronized void runMemoryRevoking()
{
if (checkPending.getAndSet(false)) {
Collection<SqlTask> allTasks = null;
for (MemoryPool memoryPool : memoryPools) {
if (!memoryRevokingNeeded(memoryPool)) {
continue;
}

if (allTasks == null) {
allTasks = requireNonNull(currentTasksSupplier.get());
}

requestMemoryRevoking(memoryPool, allTasks);
if (!memoryRevokingNeeded(memoryPool)) {
return;
}
requestMemoryRevoking(memoryPool, requireNonNull(currentTasksSupplier.get()));
}
}

Expand All @@ -211,7 +193,7 @@ private void requestMemoryRevoking(MemoryPool memoryPool, Collection<SqlTask> al
List<SqlTask> runningTasksInPool = findRunningTasksInMemoryPool(allTasks, memoryPool);
remainingBytesToRevoke -= getMemoryAlreadyBeingRevoked(runningTasksInPool, remainingBytesToRevoke);
if (remainingBytesToRevoke > 0) {
requestRevoking(memoryPool.getId(), runningTasksInPool, remainingBytesToRevoke);
requestRevoking(runningTasksInPool, remainingBytesToRevoke);
}
}

Expand Down Expand Up @@ -256,7 +238,7 @@ public Long mergeResults(List<Long> childrenResults)
return currentRevoking;
}

private void requestRevoking(MemoryPoolId memoryPoolId, List<SqlTask> sqlTasks, long remainingBytesToRevoke)
private void requestRevoking(List<SqlTask> sqlTasks, long remainingBytesToRevoke)
{
VoidTraversingQueryContextVisitor<AtomicLong> visitor = new VoidTraversingQueryContextVisitor<>()
{
Expand All @@ -277,7 +259,7 @@ public Void visitOperatorContext(OperatorContext operatorContext, AtomicLong rem
long revokedBytes = operatorContext.requestMemoryRevoking();
if (revokedBytes > 0) {
remainingBytesToRevoke.addAndGet(-revokedBytes);
log.debug("memoryPool=%s: requested revoking %s; remaining %s", memoryPoolId, revokedBytes, remainingBytesToRevoke.get());
log.debug("requested revoking %s; remaining %s", revokedBytes, remainingBytesToRevoke.get());
}
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import io.trino.execution.QueryTracker.TrackedQuery;
import io.trino.execution.StateMachine.StateChangeListener;
import io.trino.execution.warnings.WarningCollector;
import io.trino.memory.VersionedMemoryPoolId;
import io.trino.server.BasicQueryInfo;
import io.trino.server.protocol.Slug;
import io.trino.spi.type.Type;
Expand Down Expand Up @@ -62,10 +61,6 @@ public interface QueryExecution

DataSize getTotalMemoryReservation();

VersionedMemoryPoolId getMemoryPool();

void setMemoryPool(VersionedMemoryPoolId poolId);

void start();

void cancelQuery();
Expand Down
10 changes: 0 additions & 10 deletions core/trino-main/src/main/java/io/trino/execution/QueryInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import io.trino.spi.TrinoWarning;
import io.trino.spi.eventlistener.RoutineInfo;
import io.trino.spi.eventlistener.TableInfo;
import io.trino.spi.memory.MemoryPoolId;
import io.trino.spi.resourcegroups.QueryType;
import io.trino.spi.resourcegroups.ResourceGroupId;
import io.trino.spi.security.SelectedRole;
Expand All @@ -51,7 +50,6 @@ public class QueryInfo
private final QueryId queryId;
private final SessionRepresentation session;
private final QueryState state;
private final MemoryPoolId memoryPool;
private final boolean scheduled;
private final URI self;
private final List<String> fieldNames;
Expand Down Expand Up @@ -87,7 +85,6 @@ public QueryInfo(
@JsonProperty("queryId") QueryId queryId,
@JsonProperty("session") SessionRepresentation session,
@JsonProperty("state") QueryState state,
@JsonProperty("memoryPool") MemoryPoolId memoryPool,
@JsonProperty("scheduled") boolean scheduled,
@JsonProperty("self") URI self,
@JsonProperty("fieldNames") List<String> fieldNames,
Expand Down Expand Up @@ -145,7 +142,6 @@ public QueryInfo(
this.queryId = queryId;
this.session = session;
this.state = state;
this.memoryPool = requireNonNull(memoryPool, "memoryPool is null");
this.scheduled = scheduled;
this.self = self;
this.fieldNames = ImmutableList.copyOf(fieldNames);
Expand Down Expand Up @@ -195,12 +191,6 @@ public QueryState getState()
return state;
}

@JsonProperty
public MemoryPoolId getMemoryPool()
{
return memoryPool;
}

@JsonProperty
public boolean isScheduled()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import io.trino.execution.QueryExecution.QueryOutputInfo;
import io.trino.execution.StateMachine.StateChangeListener;
import io.trino.execution.warnings.WarningCollector;
import io.trino.memory.VersionedMemoryPoolId;
import io.trino.metadata.Metadata;
import io.trino.operator.BlockedReason;
import io.trino.operator.OperatorStats;
Expand Down Expand Up @@ -89,7 +88,6 @@
import static io.trino.execution.QueryState.TERMINAL_QUERY_STATES;
import static io.trino.execution.QueryState.WAITING_FOR_RESOURCES;
import static io.trino.execution.StageInfo.getAllStages;
import static io.trino.memory.LocalMemoryManager.GENERAL_POOL;
import static io.trino.server.DynamicFilterService.DynamicFiltersStats;
import static io.trino.spi.StandardErrorCode.NOT_FOUND;
import static io.trino.spi.StandardErrorCode.USER_CANCELED;
Expand All @@ -112,8 +110,6 @@ public class QueryStateMachine
private final Metadata metadata;
private final QueryOutputManager outputManager;

private final AtomicReference<VersionedMemoryPoolId> memoryPool = new AtomicReference<>(new VersionedMemoryPoolId(GENERAL_POOL, 0));

private final AtomicLong currentUserMemory = new AtomicLong();
private final AtomicLong peakUserMemory = new AtomicLong();

Expand Down Expand Up @@ -403,7 +399,6 @@ public BasicQueryInfo getBasicQueryInfo(Optional<BasicStageStats> rootStage)
session.toSessionRepresentation(),
Optional.of(resourceGroup),
state,
memoryPool.get().getId(),
stageStats.isScheduled(),
self,
query,
Expand Down Expand Up @@ -440,7 +435,6 @@ QueryInfo getQueryInfo(Optional<StageInfo> rootStage)
queryId,
session.toSessionRepresentation(),
state,
memoryPool.get().getId(),
isScheduled,
self,
outputManager.getQueryOutputInfo().map(QueryOutputInfo::getColumnNames).orElse(ImmutableList.of()),
Expand Down Expand Up @@ -638,16 +632,6 @@ private QueryStats getQueryStats(Optional<StageInfo> rootStage)
operatorStatsSummary.build());
}

public VersionedMemoryPoolId getMemoryPool()
{
return memoryPool.get();
}

public void setMemoryPool(VersionedMemoryPoolId memoryPool)
{
this.memoryPool.set(requireNonNull(memoryPool, "memoryPool is null"));
}

public void addOutputInfoListener(Consumer<QueryOutputInfo> listener)
{
outputManager.addOutputInfoListener(listener);
Expand Down Expand Up @@ -1105,7 +1089,6 @@ public void pruneQueryInfo()
queryInfo.getQueryId(),
queryInfo.getSession(),
queryInfo.getState(),
getMemoryPool().getId(),
queryInfo.isScheduled(),
queryInfo.getSelf(),
queryInfo.getFieldNames(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import io.trino.execution.scheduler.policy.ExecutionPolicy;
import io.trino.execution.warnings.WarningCollector;
import io.trino.failuredetector.FailureDetector;
import io.trino.memory.VersionedMemoryPoolId;
import io.trino.metadata.TableHandle;
import io.trino.operator.ForScheduler;
import io.trino.server.BasicQueryInfo;
Expand Down Expand Up @@ -276,18 +275,6 @@ public Slug getSlug()
return slug;
}

@Override
public VersionedMemoryPoolId getMemoryPool()
{
return stateMachine.getMemoryPool();
}

@Override
public void setMemoryPool(VersionedMemoryPoolId poolId)
{
stateMachine.setMemoryPool(poolId);
}

@Override
public DataSize getUserMemoryReservation()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ private QueryContext createQueryContext(
queryId,
maxQueryUserMemoryPerNode,
maxQueryMemoryPerTask,
localMemoryManager.getGeneralPool(),
localMemoryManager.getMemoryPool(),
gcMonitor,
taskNotificationExecutor,
driverYieldExecutor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public class ClusterMemoryLeakDetector

/**
* @param queryInfoSupplier All queries that the coordinator knows about.
* @param queryMemoryReservations The memory reservations of queries in the GENERAL cluster memory pool.
* @param queryMemoryReservations The memory reservations of queries in the cluster memory pool.
*/
void checkForMemoryLeaks(Supplier<List<BasicQueryInfo>> queryInfoSupplier, Map<QueryId, Long> queryMemoryReservations)
{
Expand Down
Loading

0 comments on commit cd2cdda

Please sign in to comment.