Skip to content

Commit

Permalink
Limit memory used by fault tolerant scheduler on coordinator
Browse files Browse the repository at this point in the history
  • Loading branch information
arhimondr authored and losipiuk committed Feb 9, 2022
1 parent 099ea40 commit 0862aaa
Show file tree
Hide file tree
Showing 10 changed files with 556 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
"query.remote-task.max-consecutive-error-count"})
public class QueryManagerConfig
{
public static final long AVAILABLE_HEAP_MEMORY = Runtime.getRuntime().maxMemory();

private int scheduleSplitBatchSize = 1000;
private int minScheduleSplitBatchSize = 100;
private int maxConcurrentQueries = 1000;
Expand Down Expand Up @@ -80,6 +82,7 @@ public class QueryManagerConfig

private DataSize faultTolerantExecutionTargetTaskInputSize = DataSize.of(1, GIGABYTE);
private int faultTolerantExecutionTargetTaskSplitCount = 16;
private DataSize faultTolerantExecutionTaskDescriptorStorageMaxMemory = DataSize.ofBytes(Math.round(AVAILABLE_HEAP_MEMORY * 0.15));

@Min(1)
public int getScheduleSplitBatchSize()
Expand Down Expand Up @@ -471,4 +474,18 @@ public QueryManagerConfig setFaultTolerantExecutionTargetTaskSplitCount(int faul
this.faultTolerantExecutionTargetTaskSplitCount = faultTolerantExecutionTargetTaskSplitCount;
return this;
}

/**
 * Returns the maximum amount of memory to be used to store task descriptors
 * for fault tolerant queries on the coordinator.
 *
 * <p>Unless overridden through the setter, the value is the field's initial
 * value: 15% of {@code AVAILABLE_HEAP_MEMORY}, which is captured from
 * {@code Runtime.getRuntime().maxMemory()} when this class is initialized.
 *
 * @return the configured limit as a {@code DataSize}
 */
@NotNull
public DataSize getFaultTolerantExecutionTaskDescriptorStorageMaxMemory()
{
return faultTolerantExecutionTaskDescriptorStorageMaxMemory;
}

/**
 * Sets the maximum amount of memory to be used to store task descriptors for
 * fault tolerant queries on the coordinator.
 *
 * <p>Bound to the configuration property
 * {@code fault-tolerant-execution-task-descriptor-storage-max-memory}.
 * The value is stored as given; this method performs no null or range check itself.
 * NOTE(review): null is presumably rejected by validation of the getter's
 * {@code @NotNull} constraint rather than here; confirm.
 *
 * @param faultTolerantExecutionTaskDescriptorStorageMaxMemory the new limit
 * @return this config instance, to allow chained setter calls
 */
@Config("fault-tolerant-execution-task-descriptor-storage-max-memory")
@ConfigDescription("Maximum amount of memory to be used to store task descriptors for fault tolerant queries on coordinator")
public QueryManagerConfig setFaultTolerantExecutionTaskDescriptorStorageMaxMemory(DataSize faultTolerantExecutionTaskDescriptorStorageMaxMemory)
{
this.faultTolerantExecutionTaskDescriptorStorageMaxMemory = faultTolerantExecutionTaskDescriptorStorageMaxMemory;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.trino.execution.scheduler.NodeScheduler;
import io.trino.execution.scheduler.SplitSchedulerStats;
import io.trino.execution.scheduler.SqlQueryScheduler;
import io.trino.execution.scheduler.TaskDescriptorStorage;
import io.trino.execution.scheduler.TaskSourceFactory;
import io.trino.execution.scheduler.policy.ExecutionPolicy;
import io.trino.execution.warnings.WarningCollector;
Expand Down Expand Up @@ -121,6 +122,7 @@ public class SqlQueryExecution
private final TaskManager coordinatorTaskManager;
private final ExchangeManagerRegistry exchangeManagerRegistry;
private final TaskSourceFactory taskSourceFactory;
private final TaskDescriptorStorage taskDescriptorStorage;

private SqlQueryExecution(
PreparedQuery preparedQuery,
Expand Down Expand Up @@ -149,7 +151,8 @@ private SqlQueryExecution(
TypeAnalyzer typeAnalyzer,
TaskManager coordinatorTaskManager,
ExchangeManagerRegistry exchangeManagerRegistry,
TaskSourceFactory taskSourceFactory)
TaskSourceFactory taskSourceFactory,
TaskDescriptorStorage taskDescriptorStorage)
{
try (SetThreadName ignored = new SetThreadName("Query-%s", stateMachine.getQueryId())) {
this.slug = requireNonNull(slug, "slug is null");
Expand Down Expand Up @@ -207,6 +210,7 @@ private SqlQueryExecution(
this.coordinatorTaskManager = requireNonNull(coordinatorTaskManager, "coordinatorTaskManager is null");
this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null");
this.taskSourceFactory = requireNonNull(taskSourceFactory, "taskSourceFactory is null");
this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null");
}
}

Expand Down Expand Up @@ -521,7 +525,8 @@ private void planDistribution(PlanRoot plan)
splitSourceFactory,
coordinatorTaskManager,
exchangeManagerRegistry,
taskSourceFactory);
taskSourceFactory,
taskDescriptorStorage);

queryScheduler.set(scheduler);

Expand Down Expand Up @@ -709,6 +714,7 @@ public static class SqlQueryExecutionFactory
private final TaskManager coordinatorTaskManager;
private final ExchangeManagerRegistry exchangeManagerRegistry;
private final TaskSourceFactory taskSourceFactory;
private final TaskDescriptorStorage taskDescriptorStorage;

@Inject
SqlQueryExecutionFactory(
Expand All @@ -734,7 +740,8 @@ public static class SqlQueryExecutionFactory
TypeAnalyzer typeAnalyzer,
TaskManager coordinatorTaskManager,
ExchangeManagerRegistry exchangeManagerRegistry,
TaskSourceFactory taskSourceFactory)
TaskSourceFactory taskSourceFactory,
TaskDescriptorStorage taskDescriptorStorage)
{
requireNonNull(config, "config is null");
this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null");
Expand All @@ -760,6 +767,7 @@ public static class SqlQueryExecutionFactory
this.coordinatorTaskManager = requireNonNull(coordinatorTaskManager, "coordinatorTaskManager is null");
this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null");
this.taskSourceFactory = requireNonNull(taskSourceFactory, "taskSourceFactory is null");
this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null");
}

@Override
Expand Down Expand Up @@ -800,7 +808,8 @@ public QueryExecution createQueryExecution(
typeAnalyzer,
coordinatorTaskManager,
exchangeManagerRegistry,
taskSourceFactory);
taskSourceFactory,
taskDescriptorStorage);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ public class FaultTolerantStageScheduler
private final FailureDetector failureDetector;
private final TaskSourceFactory taskSourceFactory;
private final NodeAllocator nodeAllocator;
private final TaskDescriptorStorage taskDescriptorStorage;

private final TaskLifecycleListener taskLifecycleListener;
// empty when the results are consumed via a direct exchange
Expand All @@ -114,8 +115,6 @@ public class FaultTolerantStageScheduler
@GuardedBy("this")
private TaskSource taskSource;
@GuardedBy("this")
private final Map<Integer, TaskDescriptor> partitionToTaskDescriptorMap = new HashMap<>();
@GuardedBy("this")
private final Map<Integer, ExchangeSinkHandle> partitionToExchangeSinkHandleMap = new HashMap<>();
@GuardedBy("this")
private final Multimap<Integer, RemoteTask> partitionToRemoteTaskMap = ArrayListMultimap.create();
Expand All @@ -124,6 +123,8 @@ public class FaultTolerantStageScheduler
@GuardedBy("this")
private final Map<TaskId, InternalNode> runningNodes = new HashMap<>();
@GuardedBy("this")
private final Set<Integer> allPartitions = new HashSet<>();
@GuardedBy("this")
private final Queue<Integer> queuedPartitions = new ArrayDeque<>();
@GuardedBy("this")
private final Set<Integer> finishedPartitions = new HashSet<>();
Expand All @@ -141,6 +142,7 @@ public FaultTolerantStageScheduler(
FailureDetector failureDetector,
TaskSourceFactory taskSourceFactory,
NodeAllocator nodeAllocator,
TaskDescriptorStorage taskDescriptorStorage,
TaskLifecycleListener taskLifecycleListener,
Optional<Exchange> sinkExchange,
Optional<int[]> sinkBucketToPartitionMap,
Expand All @@ -156,6 +158,7 @@ public FaultTolerantStageScheduler(
this.failureDetector = requireNonNull(failureDetector, "failureDetector is null");
this.taskSourceFactory = requireNonNull(taskSourceFactory, "taskSourceFactory is null");
this.nodeAllocator = requireNonNull(nodeAllocator, "nodeAllocator is null");
this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null");
this.taskLifecycleListener = requireNonNull(taskLifecycleListener, "taskLifecycleListener is null");
this.sinkExchange = requireNonNull(sinkExchange, "sinkExchange is null");
this.sinkBucketToPartitionMap = requireNonNull(sinkBucketToPartitionMap, "sinkBucketToPartitionMap is null");
Expand Down Expand Up @@ -227,7 +230,8 @@ public synchronized void schedule()
List<TaskDescriptor> tasks = taskSource.getMoreTasks();
for (TaskDescriptor task : tasks) {
queuedPartitions.add(task.getPartitionId());
partitionToTaskDescriptorMap.put(task.getPartitionId(), task);
allPartitions.add(task.getPartitionId());
taskDescriptorStorage.put(stage.getStageId(), task);
sinkExchange.ifPresent(exchange -> {
ExchangeSinkHandle exchangeSinkHandle = exchange.addSink(task.getPartitionId());
partitionToExchangeSinkHandleMap.put(task.getPartitionId(), exchangeSinkHandle);
Expand All @@ -243,7 +247,12 @@ public synchronized void schedule()
}

int partition = queuedPartitions.peek();
TaskDescriptor taskDescriptor = requireNonNull(partitionToTaskDescriptorMap.get(partition), () -> "task descriptor missing for partition: %s" + partition);
Optional<TaskDescriptor> taskDescriptorOptional = taskDescriptorStorage.get(stage.getStageId(), partition);
if (taskDescriptorOptional.isEmpty()) {
// query has been terminated
return;
}
TaskDescriptor taskDescriptor = taskDescriptorOptional.get();

if (acquireNodeFuture == null) {
acquireNodeFuture = nodeAllocator.acquire(taskDescriptor.getNodeRequirements());
Expand Down Expand Up @@ -326,7 +335,7 @@ public synchronized boolean isFinished()
taskSource != null &&
taskSource.isFinished() &&
queuedPartitions.isEmpty() &&
finishedPartitions.containsAll(partitionToTaskDescriptorMap.keySet());
finishedPartitions.containsAll(allPartitions);
}

public void cancel()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ public class SqlQueryScheduler
private final SplitSourceFactory splitSourceFactory;
private final ExchangeManagerRegistry exchangeManagerRegistry;
private final TaskSourceFactory taskSourceFactory;
private final TaskDescriptorStorage taskDescriptorStorage;

private final StageManager stageManager;
private final CoordinatorStagesScheduler coordinatorStagesScheduler;
Expand Down Expand Up @@ -224,7 +225,8 @@ public SqlQueryScheduler(
SplitSourceFactory splitSourceFactory,
TaskManager coordinatorTaskManager,
ExchangeManagerRegistry exchangeManagerRegistry,
TaskSourceFactory taskSourceFactory)
TaskSourceFactory taskSourceFactory,
TaskDescriptorStorage taskDescriptorStorage)
{
this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null");
this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null");
Expand All @@ -240,6 +242,7 @@ public SqlQueryScheduler(
this.splitSourceFactory = requireNonNull(splitSourceFactory, "splitSourceFactory is null");
this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null");
this.taskSourceFactory = requireNonNull(taskSourceFactory, "taskSourceFactory is null");
this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null");

stageManager = StageManager.create(
queryStateMachine,
Expand Down Expand Up @@ -332,6 +335,7 @@ private synchronized Optional<DistributedStagesScheduler> createDistributedStage
stageManager,
failureDetector,
taskSourceFactory,
taskDescriptorStorage,
exchangeManager,
nodePartitioningManager,
coordinatorStagesScheduler.getTaskLifecycleListener(),
Expand Down Expand Up @@ -1702,6 +1706,7 @@ public static FaultTolerantDistributedStagesScheduler create(
StageManager stageManager,
FailureDetector failureDetector,
TaskSourceFactory taskSourceFactory,
TaskDescriptorStorage taskDescriptorStorage,
ExchangeManager exchangeManager,
NodePartitioningManager nodePartitioningManager,
TaskLifecycleListener coordinatorTaskLifecycleListener,
Expand All @@ -1710,6 +1715,13 @@ public static FaultTolerantDistributedStagesScheduler create(
SplitSchedulerStats schedulerStats,
NodeScheduler nodeScheduler)
{
taskDescriptorStorage.initialize(queryStateMachine.getQueryId());
queryStateMachine.addStateChangeListener(state -> {
if (state.isDone()) {
taskDescriptorStorage.destroy(queryStateMachine.getQueryId());
}
});

DistributedStagesSchedulerStateMachine stateMachine = new DistributedStagesSchedulerStateMachine(queryStateMachine.getQueryId(), scheduledExecutorService);

Session session = queryStateMachine.getSession();
Expand Down Expand Up @@ -1764,6 +1776,7 @@ public static FaultTolerantDistributedStagesScheduler create(
failureDetector,
taskSourceFactory,
nodeAllocator,
taskDescriptorStorage,
taskLifecycleListener,
exchange,
bucketToPartitionCache.apply(fragment.getPartitioningScheme().getPartitioning().getHandle()).getBucketToPartitionMap(),
Expand Down
Loading

0 comments on commit 0862aaa

Please sign in to comment.