diff --git a/Common/src/Storage/TaskLifeCycleHelper.cs b/Common/src/Storage/TaskLifeCycleHelper.cs
index 16b31f9ef..513c63f4f 100644
--- a/Common/src/Storage/TaskLifeCycleHelper.cs
+++ b/Common/src/Storage/TaskLifeCycleHelper.cs
@@ -209,29 +209,59 @@ public static async Task FinalizeTaskCreation(ITaskTable t
return;
}
- async Task TransferOwnership()
+ var prepareTaskDependencies = PrepareTaskDependencies(taskTable,
+ resultTable,
+ taskRequests,
+ sessionId,
+ logger,
+ cancellationToken);
+
+ // Transfer ownership while dependencies are in preparation
+ if (!parentTaskId.Equals(sessionId))
{
- if (!parentTaskId.Equals(sessionId))
- {
- var parentExpectedOutputKeys = (await taskTable.GetTaskExpectedOutputKeys(parentTaskId,
- cancellationToken)
- .ConfigureAwait(false)).ToHashSet();
- var taskDataModels =
- taskRequests.Select(request => new IResultTable.ChangeResultOwnershipRequest(request.ExpectedOutputKeys.Where(id => parentExpectedOutputKeys.Contains(id)),
- request.TaskId));
- await resultTable.ChangeResultOwnership(sessionId,
- parentTaskId,
- taskDataModels,
- cancellationToken)
- .ConfigureAwait(false);
- }
+ var parentExpectedOutputKeys = (await taskTable.GetTaskExpectedOutputKeys(parentTaskId,
+ cancellationToken)
+ .ConfigureAwait(false)).ToHashSet();
+ var taskDataModels =
+ taskRequests.Select(request => new IResultTable.ChangeResultOwnershipRequest(request.ExpectedOutputKeys.Where(id => parentExpectedOutputKeys.Contains(id)),
+ request.TaskId));
+ await resultTable.ChangeResultOwnership(sessionId,
+ parentTaskId,
+ taskDataModels,
+ cancellationToken)
+ .ConfigureAwait(false);
}
- var transferOwnership = TransferOwnership();
-
+ await EnqueueReadyTasks(taskTable,
+ pushQueueStorage,
+ await prepareTaskDependencies.ConfigureAwait(false),
+ cancellationToken)
+ .ConfigureAwait(false);
+ }
+ ///
+ /// Collect and record all the task dependencies specified in the ,
+ /// and return all the tasks that are ready to be enqueued.
+ ///
+ /// Interface to manage task states
+ /// Interface to manage result states
+ /// Tasks requests to finalize
+ /// Session Id of the completed results
+ /// Logger used to produce logs
+ /// Token used to cancel the execution of the method
+ ///
+ /// Queue messages for ready tasks
+ ///
+ private static async Task> PrepareTaskDependencies(ITaskTable taskTable,
+ IResultTable resultTable,
+ ICollection taskRequests,
+ string sessionId,
+ ILogger logger,
+ CancellationToken cancellationToken)
+ {
var allDependencies = new HashSet();
+ // Get all the results that are a dependency of at least one task
foreach (var request in taskRequests)
{
allDependencies.UnionWith(request.DataDependencies);
@@ -247,6 +277,7 @@ await resultTable.ChangeResultOwnership(sessionId,
allDependencies.Remove(resultId);
}
+ // Build the mapping between tasks and their dependencies
var taskDependencies = taskRequests.ToDictionary(request => request.TaskId,
request =>
{
@@ -290,52 +321,13 @@ await taskTable.RemoveRemainingDataDependenciesAsync(taskDependencies.Keys,
cancellationToken)
.ConfigureAwait(false);
- var readyTasks = new Dictionary>();
-
- foreach (var request in taskRequests)
- {
- var taskId = request.TaskId;
- if (taskDependencies[taskId]
- .Any(resultId => allDependencies.Contains(resultId)))
- {
- continue;
- }
-
- if (readyTasks.TryGetValue(request.Options.PartitionId,
- out var msgsData))
- {
- msgsData.Add(new MessageData(request.TaskId,
- sessionId,
- request.Options));
- }
- else
- {
- readyTasks.Add(request.Options.PartitionId,
- new List
- {
- new(request.TaskId,
- sessionId,
- request.Options),
- });
- }
- }
-
- await transferOwnership.ConfigureAwait(false);
-
- if (readyTasks.Any())
- {
- foreach (var item in readyTasks)
- {
- await pushQueueStorage.PushMessagesAsync(item.Value,
- item.Key,
- cancellationToken)
- .ConfigureAwait(false);
- await taskTable.FinalizeTaskCreation(item.Value.Select(data => data.TaskId)
- .ToList(),
- cancellationToken)
- .ConfigureAwait(false);
- }
- }
+ // Return all the tasks that are ready and shall be enqueued
+ return taskRequests.Where(request => !taskDependencies[request.TaskId]
+ .Any(resultId => allDependencies.Contains(resultId)))
+ .Select(request => new MessageData(request.TaskId,
+ sessionId,
+ request.Options))
+ .AsICollection();
}
///
@@ -391,36 +383,53 @@ await taskTable.RemoveRemainingDataDependenciesAsync(dependentTasks,
// Find all tasks whose dependencies are now complete in order to start them.
// Multiple agents can see the same task as ready and will try to start it multiple times.
// This is benign as it will be handled during dequeue with message deduplication.
- var groups = (await taskTable.FindTasksAsync(data => dependentTasks.Contains(data.TaskId) && data.Status == TaskStatus.Creating &&
- data.RemainingDataDependencies == new Dictionary(),
- data => new
- {
- data.TaskId,
- data.SessionId,
- data.Options,
- data.Options.PartitionId,
- data.Options.Priority,
- },
- cancellationToken)
- .ToListAsync(cancellationToken)
- .ConfigureAwait(false)).GroupBy(data => (data.PartitionId, data.Priority));
+ var readyTasks = await taskTable.FindTasksAsync(data => dependentTasks.Contains(data.TaskId) && data.Status == TaskStatus.Creating &&
+ data.RemainingDataDependencies == new Dictionary(),
+ data => new MessageData(data.TaskId,
+ data.SessionId,
+ data.Options),
+ cancellationToken)
+ .ToListAsync(cancellationToken)
+ .ConfigureAwait(false);
+
+ await EnqueueReadyTasks(taskTable,
+ pushQueueStorage,
+ readyTasks,
+ cancellationToken)
+ .ConfigureAwait(false);
+ }
- foreach (var group in groups)
+ ///
+ /// Enqueue all the messages that are ready for enqueueing, and mark them as enqueued in the task table
+ ///
+ /// Interface to manage task states
+ /// Interface to push tasks in the queue
+ /// Messages to enqueue
+ /// Token used to cancel the execution of the method
+ ///
+ /// Task representing the asynchronous execution of the method
+ ///
+ private static async Task EnqueueReadyTasks(ITaskTable taskTable,
+ IPushQueueStorage pushQueueStorage,
+ ICollection messages,
+ CancellationToken cancellationToken)
+ {
+ if (!messages.Any())
{
- var ids = group.Select(data => data.TaskId)
- .ToList();
+ return;
+ }
- var msgsData = group.Select(data => new MessageData(data.TaskId,
- data.SessionId,
- data.Options));
- await pushQueueStorage.PushMessagesAsync(msgsData,
+ foreach (var group in messages.GroupBy(msg => (msg.Options.PartitionId, msg.Options.Priority)))
+ {
+ await pushQueueStorage.PushMessagesAsync(group,
group.Key.PartitionId,
cancellationToken)
.ConfigureAwait(false);
-
- await taskTable.FinalizeTaskCreation(ids,
- cancellationToken)
- .ConfigureAwait(false);
}
+
+ await taskTable.FinalizeTaskCreation(messages.Select(task => task.TaskId)
+ .AsICollection(),
+ cancellationToken)
+ .ConfigureAwait(false);
}
}