Skip to content

Commit

Permalink
Factorize some code
Browse files: browse the repository at this point in the history
  • Loading branch information
lemaitre-aneo committed Nov 28, 2023
1 parent 63d7136 commit 5030418
Showing 1 changed file with 96 additions and 87 deletions.
183 changes: 96 additions & 87 deletions Common/src/Storage/TaskLifeCycleHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -209,29 +209,59 @@ public static async Task FinalizeTaskCreation(ITaskTable t
return;
}

async Task TransferOwnership()
var prepareTaskDependencies = PrepareTaskDependencies(taskTable,
resultTable,
taskRequests,
sessionId,
logger,
cancellationToken);

// Transfer ownership while dependencies are in preparation
if (!parentTaskId.Equals(sessionId))
{
if (!parentTaskId.Equals(sessionId))
{
var parentExpectedOutputKeys = (await taskTable.GetTaskExpectedOutputKeys(parentTaskId,
cancellationToken)
.ConfigureAwait(false)).ToHashSet();
var taskDataModels =
taskRequests.Select(request => new IResultTable.ChangeResultOwnershipRequest(request.ExpectedOutputKeys.Where(id => parentExpectedOutputKeys.Contains(id)),
request.TaskId));
await resultTable.ChangeResultOwnership(sessionId,
parentTaskId,
taskDataModels,
cancellationToken)
.ConfigureAwait(false);
}
var parentExpectedOutputKeys = (await taskTable.GetTaskExpectedOutputKeys(parentTaskId,
cancellationToken)
.ConfigureAwait(false)).ToHashSet();
var taskDataModels =
taskRequests.Select(request => new IResultTable.ChangeResultOwnershipRequest(request.ExpectedOutputKeys.Where(id => parentExpectedOutputKeys.Contains(id)),
request.TaskId));
await resultTable.ChangeResultOwnership(sessionId,
parentTaskId,
taskDataModels,
cancellationToken)
.ConfigureAwait(false);
}

var transferOwnership = TransferOwnership();

await EnqueueReadyTasks(taskTable,
pushQueueStorage,
await prepareTaskDependencies.ConfigureAwait(false),
cancellationToken)
.ConfigureAwait(false);
}

/// <summary>
/// Collect and record all the task dependencies specified in the <paramref name="taskRequests" />,
/// and return all the tasks that are ready to be enqueued.
/// </summary>
/// <param name="taskTable">Interface to manage task states</param>
/// <param name="resultTable">Interface to manage result states</param>
/// <param name="taskRequests">Task requests to finalize</param>
/// <param name="sessionId">Session Id of the completed results</param>
/// <param name="logger">Logger used to produce logs</param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// Queue messages for ready tasks
/// </returns>
private static async Task<ICollection<MessageData>> PrepareTaskDependencies(ITaskTable taskTable,
IResultTable resultTable,
ICollection<TaskCreationRequest> taskRequests,
string sessionId,
ILogger logger,
CancellationToken cancellationToken)
{
var allDependencies = new HashSet<string>();

// Get all the results that are a dependency of at least one task
foreach (var request in taskRequests)
{
allDependencies.UnionWith(request.DataDependencies);
Expand All @@ -247,6 +277,7 @@ await resultTable.ChangeResultOwnership(sessionId,
allDependencies.Remove(resultId);
}

// Build the mapping between tasks and their dependencies
var taskDependencies = taskRequests.ToDictionary(request => request.TaskId,
request =>
{
Expand Down Expand Up @@ -290,52 +321,13 @@ await taskTable.RemoveRemainingDataDependenciesAsync(taskDependencies.Keys,
cancellationToken)
.ConfigureAwait(false);

var readyTasks = new Dictionary<string, List<MessageData>>();

foreach (var request in taskRequests)
{
var taskId = request.TaskId;
if (taskDependencies[taskId]
.Any(resultId => allDependencies.Contains(resultId)))
{
continue;
}

if (readyTasks.TryGetValue(request.Options.PartitionId,
out var msgsData))
{
msgsData.Add(new MessageData(request.TaskId,
sessionId,
request.Options));
}
else
{
readyTasks.Add(request.Options.PartitionId,
new List<MessageData>
{
new(request.TaskId,
sessionId,
request.Options),
});
}
}

await transferOwnership.ConfigureAwait(false);

if (readyTasks.Any())
{
foreach (var item in readyTasks)
{
await pushQueueStorage.PushMessagesAsync(item.Value,
item.Key,
cancellationToken)
.ConfigureAwait(false);
await taskTable.FinalizeTaskCreation(item.Value.Select(data => data.TaskId)
.ToList(),
cancellationToken)
.ConfigureAwait(false);
}
}
// Return all the tasks that are ready and shall be enqueued
return taskRequests.Where(request => !taskDependencies[request.TaskId]
.Any(resultId => allDependencies.Contains(resultId)))
.Select(request => new MessageData(request.TaskId,
sessionId,
request.Options))
.AsICollection();
}

/// <summary>
Expand Down Expand Up @@ -391,36 +383,53 @@ await taskTable.RemoveRemainingDataDependenciesAsync(dependentTasks,
// Find all tasks whose dependencies are now complete in order to start them.
// Multiple agents can see the same task as ready and will try to start it multiple times.
// This is benign as it will be handled during dequeue with message deduplication.
var groups = (await taskTable.FindTasksAsync(data => dependentTasks.Contains(data.TaskId) && data.Status == TaskStatus.Creating &&
data.RemainingDataDependencies == new Dictionary<string, bool>(),
data => new
{
data.TaskId,
data.SessionId,
data.Options,
data.Options.PartitionId,
data.Options.Priority,
},
cancellationToken)
.ToListAsync(cancellationToken)
.ConfigureAwait(false)).GroupBy(data => (data.PartitionId, data.Priority));
var readyTasks = await taskTable.FindTasksAsync(data => dependentTasks.Contains(data.TaskId) && data.Status == TaskStatus.Creating &&
data.RemainingDataDependencies == new Dictionary<string, bool>(),
data => new MessageData(data.TaskId,
data.SessionId,
data.Options),
cancellationToken)
.ToListAsync(cancellationToken)
.ConfigureAwait(false);

await EnqueueReadyTasks(taskTable,
pushQueueStorage,
readyTasks,
cancellationToken)
.ConfigureAwait(false);
}

foreach (var group in groups)
/// <summary>
/// Enqueue all the messages that are ready for enqueueing, and mark them as enqueued in the task table
/// </summary>
/// <param name="taskTable">Interface to manage task states</param>
/// <param name="pushQueueStorage">Interface to push tasks in the queue</param>
/// <param name="messages">Messages to enqueue</param>
/// <param name="cancellationToken">Token used to cancel the execution of the method</param>
/// <returns>
/// Task representing the asynchronous execution of the method
/// </returns>
private static async Task EnqueueReadyTasks(ITaskTable taskTable,
IPushQueueStorage pushQueueStorage,
ICollection<MessageData> messages,
CancellationToken cancellationToken)
{
if (!messages.Any())
{
var ids = group.Select(data => data.TaskId)
.ToList();
return;
}

var msgsData = group.Select(data => new MessageData(data.TaskId,
data.SessionId,
data.Options));
await pushQueueStorage.PushMessagesAsync(msgsData,
foreach (var group in messages.GroupBy(msg => (msg.Options.PartitionId, msg.Options.Priority)))
{
await pushQueueStorage.PushMessagesAsync(group,
group.Key.PartitionId,
cancellationToken)
.ConfigureAwait(false);

await taskTable.FinalizeTaskCreation(ids,
cancellationToken)
.ConfigureAwait(false);
}

await taskTable.FinalizeTaskCreation(messages.Select(task => task.TaskId)
.AsICollection(),
cancellationToken)
.ConfigureAwait(false);
}
}

0 comments on commit 5030418

Please sign in to comment.