diff --git a/Common/src/Storage/TaskLifeCycleHelper.cs b/Common/src/Storage/TaskLifeCycleHelper.cs index 16b31f9ef..513c63f4f 100644 --- a/Common/src/Storage/TaskLifeCycleHelper.cs +++ b/Common/src/Storage/TaskLifeCycleHelper.cs @@ -209,29 +209,59 @@ public static async Task FinalizeTaskCreation(ITaskTable t return; } - async Task TransferOwnership() + var prepareTaskDependencies = PrepareTaskDependencies(taskTable, + resultTable, + taskRequests, + sessionId, + logger, + cancellationToken); + + // Transfer ownership while dependencies are in preparation + if (!parentTaskId.Equals(sessionId)) { - if (!parentTaskId.Equals(sessionId)) - { - var parentExpectedOutputKeys = (await taskTable.GetTaskExpectedOutputKeys(parentTaskId, - cancellationToken) - .ConfigureAwait(false)).ToHashSet(); - var taskDataModels = - taskRequests.Select(request => new IResultTable.ChangeResultOwnershipRequest(request.ExpectedOutputKeys.Where(id => parentExpectedOutputKeys.Contains(id)), - request.TaskId)); - await resultTable.ChangeResultOwnership(sessionId, - parentTaskId, - taskDataModels, - cancellationToken) - .ConfigureAwait(false); - } + var parentExpectedOutputKeys = (await taskTable.GetTaskExpectedOutputKeys(parentTaskId, + cancellationToken) + .ConfigureAwait(false)).ToHashSet(); + var taskDataModels = + taskRequests.Select(request => new IResultTable.ChangeResultOwnershipRequest(request.ExpectedOutputKeys.Where(id => parentExpectedOutputKeys.Contains(id)), + request.TaskId)); + await resultTable.ChangeResultOwnership(sessionId, + parentTaskId, + taskDataModels, + cancellationToken) + .ConfigureAwait(false); } - var transferOwnership = TransferOwnership(); - + await EnqueueReadyTasks(taskTable, + pushQueueStorage, + await prepareTaskDependencies.ConfigureAwait(false), + cancellationToken) + .ConfigureAwait(false); + } + /// + /// Collect and record all the task dependencies specified in the , + /// and return all the tasks that are ready to be enqueued. + /// + /// Interface to manage task states + /// Interface to manage result states + /// Tasks requests to finalize + /// Session Id of the completed results + /// Logger used to produce logs + /// Token used to cancel the execution of the method + /// + /// Queue messages for ready tasks + /// + private static async Task> PrepareTaskDependencies(ITaskTable taskTable, + IResultTable resultTable, + ICollection taskRequests, + string sessionId, + ILogger logger, + CancellationToken cancellationToken) + { var allDependencies = new HashSet(); + // Get all the results that are a dependency of at least one task foreach (var request in taskRequests) { allDependencies.UnionWith(request.DataDependencies); @@ -247,6 +277,7 @@ await resultTable.ChangeResultOwnership(sessionId, allDependencies.Remove(resultId); } + // Build the mapping between tasks and their dependencies var taskDependencies = taskRequests.ToDictionary(request => request.TaskId, request => { @@ -290,52 +321,13 @@ await taskTable.RemoveRemainingDataDependenciesAsync(taskDependencies.Keys, cancellationToken) .ConfigureAwait(false); - var readyTasks = new Dictionary>(); - - foreach (var request in taskRequests) - { - var taskId = request.TaskId; - if (taskDependencies[taskId] - .Any(resultId => allDependencies.Contains(resultId))) - { - continue; - } - - if (readyTasks.TryGetValue(request.Options.PartitionId, - out var msgsData)) - { - msgsData.Add(new MessageData(request.TaskId, - sessionId, - request.Options)); - } - else - { - readyTasks.Add(request.Options.PartitionId, - new List - { - new(request.TaskId, - sessionId, - request.Options), - }); - } - } - - await transferOwnership.ConfigureAwait(false); - - if (readyTasks.Any()) - { - foreach (var item in readyTasks) - { - await pushQueueStorage.PushMessagesAsync(item.Value, - item.Key, - cancellationToken) - .ConfigureAwait(false); - await taskTable.FinalizeTaskCreation(item.Value.Select(data => data.TaskId) - .ToList(), - cancellationToken) - .ConfigureAwait(false); - } - } + // Return all the tasks that are ready and shall be enqueued + return taskRequests.Where(request => !taskDependencies[request.TaskId] + .Any(resultId => allDependencies.Contains(resultId))) + .Select(request => new MessageData(request.TaskId, + sessionId, + request.Options)) + .AsICollection(); } /// @@ -391,36 +383,53 @@ await taskTable.RemoveRemainingDataDependenciesAsync(dependentTasks, // Find all tasks whose dependencies are now complete in order to start them. // Multiple agents can see the same task as ready and will try to start it multiple times. // This is benign as it will be handled during dequeue with message deduplication. - var groups = (await taskTable.FindTasksAsync(data => dependentTasks.Contains(data.TaskId) && data.Status == TaskStatus.Creating && - data.RemainingDataDependencies == new Dictionary(), - data => new - { - data.TaskId, - data.SessionId, - data.Options, - data.Options.PartitionId, - data.Options.Priority, - }, - cancellationToken) - .ToListAsync(cancellationToken) - .ConfigureAwait(false)).GroupBy(data => (data.PartitionId, data.Priority)); + var readyTasks = await taskTable.FindTasksAsync(data => dependentTasks.Contains(data.TaskId) && data.Status == TaskStatus.Creating && + data.RemainingDataDependencies == new Dictionary(), + data => new MessageData(data.TaskId, + data.SessionId, + data.Options), + cancellationToken) + .ToListAsync(cancellationToken) + .ConfigureAwait(false); + + await EnqueueReadyTasks(taskTable, + pushQueueStorage, + readyTasks, + cancellationToken) + .ConfigureAwait(false); + } - foreach (var group in groups) + /// + /// Enqueue all the messages that are ready for enqueueing, and mark them as enqueued in the task table + /// + /// Interface to manage task states + /// Interface to push tasks in the queue + /// Messages to enqueue + /// Token used to cancel the execution of the method + /// + /// Task representing the asynchronous execution of the method + /// + private static async Task EnqueueReadyTasks(ITaskTable taskTable, + IPushQueueStorage pushQueueStorage, + ICollection messages, + CancellationToken cancellationToken) + { + if (!messages.Any()) { - var ids = group.Select(data => data.TaskId) - .ToList(); + return; + } - var msgsData = group.Select(data => new MessageData(data.TaskId, - data.SessionId, - data.Options)); - await pushQueueStorage.PushMessagesAsync(msgsData, + foreach (var group in messages.GroupBy(msg => (msg.Options.PartitionId, msg.Options.Priority))) + { + await pushQueueStorage.PushMessagesAsync(group, group.Key.PartitionId, cancellationToken) .ConfigureAwait(false); - - await taskTable.FinalizeTaskCreation(ids, - cancellationToken) - .ConfigureAwait(false); } + + await taskTable.FinalizeTaskCreation(messages.Select(task => task.TaskId) + .AsICollection(), + cancellationToken) + .ConfigureAwait(false); } }